Source code for torchdrug.utils.comm

import os
import multiprocessing
from collections import defaultdict

import torch
from torch import distributed as dist


cpu_group = None
gpu_group = None


def get_rank():
    """
    Get the rank of this process in distributed processes.

    Return 0 for single process case.
    """
    if dist.is_initialized():
        return dist.get_rank()
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    return 0

def get_world_size():
    """
    Get the total number of distributed processes.

    Return 1 for single process case.
    """
    if dist.is_initialized():
        return dist.get_world_size()
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    return 1

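For reference, a minimal single-process sketch (not part of the module) of the environment-variable fallback used by ``get_rank`` and ``get_world_size``, assuming ``torch.distributed`` has not been initialized; the ``RANK`` / ``WORLD_SIZE`` values below are illustrative and would normally be set by a launcher.

import os

os.environ["RANK"] = "2"           # normally set by a launcher such as torchrun
os.environ["WORLD_SIZE"] = "4"

from torchdrug.utils import comm

assert comm.get_rank() == 2        # falls back to the RANK environment variable
assert comm.get_world_size() == 4  # falls back to WORLD_SIZE
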
def get_group(device):
    """
    Get the process group corresponding to the given device.

    Parameters:
        device (torch.device): query device
    """
    group = cpu_group if device.type == "cpu" else gpu_group
    if group is None:
        raise ValueError("%s group is not initialized. Use comm.init_process_group() to initialize it"
                         % device.type.upper())
    return group

def init_process_group(backend, init_method=None, **kwargs):
    """
    Initialize CPU and/or GPU process groups.

    Parameters:
        backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs.
        init_method (str, optional): URL specifying how to initialize the process group
    """
    global cpu_group
    global gpu_group

    dist.init_process_group(backend, init_method, **kwargs)
    gpu_group = dist.group.WORLD
    if backend == "nccl":
        cpu_group = dist.new_group(backend="gloo")
    else:
        cpu_group = gpu_group

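A hedged launch sketch (not part of the module): it assumes the script is started with a launcher such as ``torchrun --nproc_per_node=4 train.py``, which sets ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``, so the ``env://`` init method can be used with the ``nccl`` backend.

import os

import torch
from torchdrug.utils import comm

# bind this process to its local GPU before any collective runs
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
comm.init_process_group("nccl", init_method="env://")

# every worker contributes a vector of ones; the element-wise sum equals the world size
x = torch.ones(3, device="cuda")
total = comm.reduce({"x": x})["x"]
assert int(total[0]) == comm.get_world_size()
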
def get_cpu_count():
    """
    Get the number of CPUs on this node.
    """
    return multiprocessing.cpu_count()

def synchronize():
    """
    Synchronize among all distributed processes.
    """
    if get_world_size() > 1:
        dist.barrier()


def _recursive_read(obj):
    # Flatten every tensor in a nested container into per-dtype lists of 1-D tensors,
    # together with their sizes, in traversal order.
    values = defaultdict(list)
    sizes = defaultdict(list)
    if isinstance(obj, torch.Tensor):
        values[obj.dtype] += [obj.flatten()]
        sizes[obj.dtype] += [torch.tensor([obj.numel()], device=obj.device)]
    elif isinstance(obj, dict):
        for v in obj.values():
            child_values, child_sizes = _recursive_read(v)
            for k, v in child_values.items():
                values[k] += v
            for k, v in child_sizes.items():
                sizes[k] += v
    elif isinstance(obj, list) or isinstance(obj, tuple):
        for v in obj:
            child_values, child_sizes = _recursive_read(v)
            for k, v in child_values.items():
                values[k] += v
            for k, v in child_sizes.items():
                sizes[k] += v
    else:
        raise ValueError("Unknown type `%s`" % type(obj))
    return values, sizes


def _recursive_write(obj, values, sizes=None):
    # Rebuild a container with the same structure as `obj`, consuming the flat
    # per-dtype buffers in `values` (split according to `sizes` when provided).
    if isinstance(obj, torch.Tensor):
        if sizes is None:
            size = torch.tensor([obj.numel()], device=obj.device)
        else:
            s = sizes[obj.dtype]
            size, s = s.split([1, len(s) - 1])
            sizes[obj.dtype] = s
        v = values[obj.dtype]
        new_obj, v = v.split([size, v.shape[-1] - size], dim=-1)
        # compatible with reduce / stack / cat
        new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:])
        values[obj.dtype] = v
        return new_obj, values
    elif isinstance(obj, dict):
        new_obj = {}
        for k, v in obj.items():
            new_obj[k], values = _recursive_write(v, values, sizes)
    elif isinstance(obj, list) or isinstance(obj, tuple):
        new_obj = []
        for v in obj:
            new_v, values = _recursive_write(v, values, sizes)
            new_obj.append(new_v)
    else:
        raise ValueError("Unknown type `%s`" % type(obj))
    return new_obj, values

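The two private helpers above are easiest to see in a single process, since they perform no communication themselves: ``_recursive_read`` flattens a nested container into one buffer per dtype, and ``_recursive_write`` restores the original structure from such buffers. A small illustrative round trip (internal API, so subject to change):

import torch
from torchdrug.utils import comm

obj = {"a": torch.arange(3), "b": [torch.ones(2, 2), torch.zeros(4)]}

values, sizes = comm._recursive_read(obj)
flat = {k: torch.cat(v) for k, v in values.items()}   # one flat buffer per dtype

new_obj, _ = comm._recursive_write(obj, flat)
assert torch.equal(new_obj["a"], obj["a"])
assert torch.equal(new_obj["b"][0], obj["b"][0])
assert torch.equal(new_obj["b"][1], obj["b"][1])
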
def reduce(obj, op="sum", dst=None):
    """
    Reduce any nested container of tensors.

    Parameters:
        obj (Object): any container object. Can be nested list, tuple or dict.
        op (str, optional): element-wise reduction operator.
            Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``.
        dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.

    Example::

        >>> # assume 4 workers
        >>> rank = comm.get_rank()
        >>> x = torch.rand(5)
        >>> obj = {"polynomial": x ** rank}
        >>> obj = comm.reduce(obj)
        >>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1)
    """
    values = _recursive_read(obj)[0]
    values = {k: torch.cat(v) for k, v in values.items()}

    is_mean = op == "mean"
    if is_mean:
        op = "sum"
    op = getattr(dist.ReduceOp, op.upper())

    reduced = {}
    for k, v in values.items():
        dtype = v.dtype
        # NCCL can't handle bool. Cast it to byte.
        if dtype == torch.bool:
            v = v.byte()
        group = get_group(v.device)
        if dst is None:
            dist.all_reduce(v, op=op, group=group)
        else:
            dist.reduce(v, op=op, dst=dst, group=group)
        if is_mean:
            v = v / get_world_size()
        reduced[k] = v.type(dtype)

    return _recursive_write(obj, reduced)[0]

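The docstring covers the default all-reduce path. As a hedged sketch of the ``op="mean"`` and ``dst`` arguments (assuming the process group was initialized as in the launch sketch above; the per-worker ``loss`` tensor is a hypothetical value), only the destination rank holds a meaningful result when ``dst`` is given:

import torch
from torchdrug.utils import comm

# hypothetical per-worker metric; each rank contributes its own value
loss = torch.tensor([float(comm.get_rank())], device="cuda")
metrics = comm.reduce({"loss": loss}, op="mean", dst=0)
if comm.get_rank() == 0:
    # only the destination rank holds the averaged result
    print("mean loss:", metrics["loss"].item())
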
def stack(obj, dst=None):
    """
    Stack any nested container of tensors. The new dimension will be added at the 0-th axis.

    Parameters:
        obj (Object): any container object. Can be nested list, tuple or dict.
        dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.

    Example::

        >>> # assume 4 workers
        >>> rank = comm.get_rank()
        >>> x = torch.rand(5)
        >>> obj = {"exponent": x ** rank}
        >>> obj = comm.stack(obj)
        >>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3])
        >>> assert torch.allclose(obj["exponent"], truth)
    """
    values = _recursive_read(obj)[0]
    values = {k: torch.cat(v) for k, v in values.items()}

    stacked = {}
    for k, v in values.items():
        dtype = v.dtype
        # NCCL can't handle bool. Cast it to byte.
        if dtype == torch.bool:
            dtype = torch.uint8
        s = torch.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device)
        s[get_rank()] = v
        group = get_group(s.device)
        if dst is None:
            dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group)
        else:
            dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group)
        stacked[k] = s.type(v.dtype)

    return _recursive_write(obj, stacked)[0]

def cat(obj, dst=None):
    """
    Concatenate any nested container of tensors along the 0-th axis.

    Parameters:
        obj (Object): any container object. Can be nested list, tuple or dict.
        dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.

    Example::

        >>> # assume 4 workers
        >>> rank = comm.get_rank()
        >>> rng = torch.arange(10)
        >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]}
        >>> obj = comm.cat(obj)
        >>> assert torch.allclose(obj["range"], rng)
    """
    values, sizes = _recursive_read(obj)
    sizes = {k: torch.cat(v) for k, v in sizes.items()}

    sizes = stack(sizes)
    cated = {}
    for k, value in values.items():
        size = sizes[k].t().flatten()  # sizes[k]: (num_worker, num_obj)
        dtype = value[0].dtype
        # NCCL can't handle bool. Cast it to byte.
        if dtype == torch.bool:
            dtype = torch.uint8
        s = torch.zeros(size.sum(), dtype=dtype, device=value[0].device)
        obj_id = get_rank()
        world_size = get_world_size()
        offset = size[:obj_id].sum()
        for v in value:
            assert offset + v.numel() <= len(s)
            s[offset: offset + v.numel()] = v
            offset += size[obj_id: obj_id + world_size].sum()
            obj_id += world_size
        group = get_group(s.device)
        if dst is None:
            dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group)
        else:
            dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group)
        cated[k] = s.type(value[0].dtype)
    sizes = {k: v.sum(dim=0) for k, v in sizes.items()}

    return _recursive_write(obj, cated, sizes)[0]