torchdrug.utils

Visualization

reaction(reactants, products, save_file=None, figure_size=(3, 3), atom_map=False)[source]

Visualize a chemical reaction.

Parameters
  • reactants (list of Molecule) – list of reactants

  • products (list of Molecule) – list of products

  • save_file (str, optional) – PNG file to save the visualization to. If not provided, the figure is shown in a window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • atom_map (bool, optional) – visualize atom mapping or not

highlight(molecule, atoms=None, bonds=None, atom_colors=None, bond_colors=None, save_file=None, figure_size=(3, 3), atom_map=False)[source]

Visualize a molecule with highlighted atoms or bonds.

Parameters
  • molecule (Molecule) – molecule to visualize

  • atoms (list of int, optional) – indices of atoms to highlight

  • bonds (list of int, optional) – indices of bonds to highlight

  • atom_colors (tuple or dict, optional) – highlight color for atoms. Can be a tuple of 3 floats between 0 and 1, or a dict that maps each index to a different color.

  • bond_colors (tuple or dict, optional) – highlight color for bonds. Can be a tuple of 3 floats between 0 and 1, or a dict that maps each index to a different color.

  • save_file (str, optional) – PNG file to save the visualization to. If not provided, the figure is shown in a window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • atom_map (bool, optional) – visualize atom mapping or not
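`highlight` accepts `atom_colors` and `bond_colors` either as a single RGB tuple shared by every highlighted item or as a dict keyed by index. A minimal sketch of how such an argument could be normalized into one color per index; `normalize_colors` is a hypothetical helper, not part of TorchDrug, and the default color and the fallback for indices missing from the dict are assumptions.

```python
from typing import Dict, Iterable, Optional, Tuple, Union

RGB = Tuple[float, float, float]


def normalize_colors(indices: Iterable[int],
                     colors: Optional[Union[RGB, Dict[int, RGB]]],
                     default: RGB = (1.0, 0.6, 0.6)) -> Dict[int, RGB]:
    """Map every highlighted index to an RGB tuple with components in [0, 1].

    Hypothetical helper: the default color and fallback behavior are assumptions.
    """
    indices = list(indices)
    if colors is None:
        # No color given: every highlighted index gets the default.
        result = {i: default for i in indices}
    elif isinstance(colors, dict):
        # Per-index colors; indices without an entry fall back to the default.
        result = {i: tuple(colors.get(i, default)) for i in indices}
    else:
        # A single tuple shared by all highlighted indices.
        result = {i: tuple(colors) for i in indices}
    for i, rgb in result.items():
        if len(rgb) != 3 or not all(0.0 <= c <= 1.0 for c in rgb):
            raise ValueError(f"color for index {i} must be 3 floats in [0, 1], got {rgb}")
    return result
```

For example, `normalize_colors([0, 2], (0.0, 0.8, 0.0))` gives both atoms the same green, while a dict lets each atom have its own color.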

echarts(graph, title=None, node_colors=None, edge_colors=None, node_labels=None, relation_labels=None, node_types=None, type_labels=None, dynamic_size=False, dynamic_width=False, save_file=None)[source]

Visualize a graph in ECharts.

Parameters
  • graph (Graph) – graph to visualize

  • title (str, optional) – title of the graph

  • node_colors (dict, optional) – specify colors for some nodes. Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.

  • edge_colors (dict, optional) – specify colors for some edges. Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.

  • node_labels (list of str, optional) – labels for each node

  • relation_labels (list of str, optional) – labels for each relation

  • node_types (list of int, optional) – type for each node

  • type_labels (list of str, optional) – labels for each node type

  • dynamic_size (bool, optional) – if true, set the size of nodes based on the logarithm of degrees

  • dynamic_width (bool, optional) – if true, set the width of edges based on the edge weights

  • save_file (str, optional) – HTML file to save the visualization to, accompanied by a JSON file
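Unlike `highlight`, `echarts` takes colors as either a tuple of 3 integers between 0 and 255 or a hex color code. A small sketch that normalizes both forms to one `#rrggbb` string; `to_hex_color` is a hypothetical helper, and lowercase output is an arbitrary choice.

```python
from typing import Tuple, Union


def to_hex_color(color: Union[str, Tuple[int, int, int]]) -> str:
    """Normalize an RGB tuple (0-255) or a hex code to a lowercase '#rrggbb' string.

    Hypothetical helper for illustration; not part of TorchDrug.
    """
    if isinstance(color, str):
        code = color.lstrip("#")
        if len(code) != 6:
            raise ValueError(f"invalid hex color: {color!r}")
        int(code, 16)  # raises ValueError if the code is not hexadecimal
        return "#" + code.lower()
    if len(color) != 3 or not all(isinstance(c, int) and 0 <= c <= 255 for c in color):
        raise ValueError(f"RGB color must be 3 integers in [0, 255], got {color!r}")
    # Two lowercase hex digits per channel.
    return "#{:02x}{:02x}{:02x}".format(*color)
```

With it, a mixed mapping such as `{0: (255, 0, 0), 3: "#00FF00"}` for `node_colors` can be checked and normalized before it is passed on.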

Auxiliary Torch Functions

load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs)[source]

Load a PyTorch C++ extension just-in-time (JIT). Compilation flags are chosen automatically if not specified.

This function performs lazy evaluation and is multi-process-safe.

See torch.utils.cpp_extension.load for more details.
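The "lazy evaluation" part means the expensive step (compiling and importing the extension) can be deferred until the extension is first used. A minimal sketch of that pattern, assuming the loader is any zero-argument callable that returns a module-like object; `LazyLoader` is a hypothetical name, and this sketch is only thread-safe. It does not attempt the cross-process coordination that the function above promises.

```python
import threading
from typing import Any, Callable


class LazyLoader:
    """Defer an expensive load (e.g. building an extension) until first attribute access.

    Hypothetical sketch: guarded by a thread lock only, not by a cross-process lock.
    """

    def __init__(self, loader: Callable[[], Any]):
        self._loader = loader
        self._module = None
        self._lock = threading.Lock()

    def _get(self) -> Any:
        if self._module is None:
            with self._lock:
                # Re-check inside the lock so only one thread runs the loader.
                if self._module is None:
                    self._module = self._loader()
        return self._module

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes not found on the wrapper itself,
        # i.e. the functions exported by the loaded module.
        return getattr(self._get(), name)
```

Creating the wrapper is free; the loader runs once, on the first call such as `ext.some_kernel(...)`, and later calls reuse the cached module.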

cpu(obj, *args, **kwargs)[source]

Transfer any nested container of tensors to CPU.

cuda(obj, *args, **kwargs)[source]

Transfer any nested container of tensors to CUDA.

detach(obj)[source]

Detach tensors in any nested container.

clone(obj, *args, **kwargs)[source]

Clone tensors in any nested container.

mean(obj, *args, **kwargs)[source]

Compute mean of tensors in any nested container.
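`cpu`, `cuda`, `detach`, `clone` and `mean` all share the same shape: walk a nested container and apply one operation at each leaf. A pure-Python sketch of that traversal; `map_nested` is a hypothetical name, and containers are rebuilt as plain `dict`, `list` or `tuple` (an assumption that drops subclasses such as named tuples).

```python
from typing import Any, Callable


def map_nested(func: Callable[[Any], Any], obj: Any) -> Any:
    """Apply func to every leaf of a nested dict/list/tuple, preserving the structure.

    Hypothetical helper; containers are rebuilt as plain dict, list or tuple.
    """
    if isinstance(obj, dict):
        return {key: map_nested(func, value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [map_nested(func, value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(map_nested(func, value) for value in obj)
    # Anything else is treated as a leaf (a tensor, in the functions above).
    return func(obj)
```

With PyTorch, `map_nested(lambda t: t.cpu(), batch)` would mirror `cpu(batch)` as long as every leaf is a tensor. Mixed leaves would need an `isinstance` guard inside `func`.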

cat(objs, *args, **kwargs)[source]

Concatenate a list of nested containers with the same structure.

stack(objs, *args, **kwargs)[source]

Stack a list of nested containers with the same structure.
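`cat` and `stack` work the other way round: they take several containers with identical structure and merge the corresponding leaves. A pure-Python sketch where `combine` receives the list of matching leaves; `combine_nested` is a hypothetical name, and plain values stand in for tensors.

```python
from typing import Any, Callable, Sequence


def combine_nested(objs: Sequence[Any], combine: Callable[[list], Any]) -> Any:
    """Merge same-structured nested containers leaf by leaf.

    Hypothetical helper: combine gets the list of corresponding leaves, one per object.
    """
    first = objs[0]
    if isinstance(first, dict):
        return {key: combine_nested([o[key] for o in objs], combine) for key in first}
    if isinstance(first, (list, tuple)):
        merged = [combine_nested([o[i] for o in objs], combine) for i in range(len(first))]
        return tuple(merged) if isinstance(first, tuple) else merged
    return combine(list(objs))
```

With tensor leaves, passing `torch.cat` or `torch.stack` as `combine` gives the concatenate and stack behavior described above, since both accept a list of tensors.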

sparse_coo_tensor(indices, values, size)[source]

Construct a sparse COO tensor without checking the indices. This is much faster than torch.sparse_coo_tensor, but the caller is responsible for passing valid indices.

Parameters
  • indices (Tensor) – 2D indices of shape (2, n)

  • values (Tensor) – values of shape (n,)

  • size (list) – size of the tensor
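The COO layout pairs column `k` of `indices` (a row and a column coordinate) with `values[k]`. A pure-Python sketch that densifies a 2D COO matrix to make the shapes concrete; `coo_to_dense` is hypothetical, and summing duplicate coordinates follows the usual COO convention. The bounds check it performs is the kind of validation the fast constructor skips.

```python
from typing import List, Sequence


def coo_to_dense(indices: Sequence[Sequence[int]], values: Sequence[float],
                 size: Sequence[int]) -> List[List[float]]:
    """Densify a 2D COO matrix: indices has shape (2, n), values has shape (n,).

    Hypothetical helper for illustration.
    """
    rows, cols = indices
    if not len(rows) == len(cols) == len(values):
        raise ValueError("indices must have shape (2, n) and values shape (n,)")
    n_rows, n_cols = size
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for r, c, v in zip(rows, cols, values):
        # Out-of-range coordinates are rejected here; a constructor without
        # index checks would leave this responsibility to the caller.
        if not (0 <= r < n_rows and 0 <= c < n_cols):
            raise IndexError(f"index ({r}, {c}) out of range for size {tuple(size)}")
        dense[r][c] += v  # duplicate coordinates are summed
    return dense
```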

Distributed Communication

init_process_group(backend, init_method=None, **kwargs)[source]

Initialize CPU and/or GPU process groups.

Parameters
  • backend (str) – Communication backend. Use nccl for GPUs and gloo for CPUs.

  • init_method (str, optional) – URL specifying how to initialize the process group

get_cpu_count()[source]

Get the number of CPUs on this node.

get_group(device)[source]

Get the process group corresponding to the given device.

Parameters
  • device (torch.device) – query device

get_rank()[source]

Get the rank of this process among all distributed processes.

Returns 0 in the single-process case.

get_world_size()[source]

Get the total number of distributed processes.

Returns 1 in the single-process case.
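Launchers such as `torchrun` export `RANK` and `WORLD_SIZE` to each worker, which makes it possible to learn the rank before a process group exists. A sketch of that environment-variable fallback with the same single-process defaults as above; `env_rank` and `env_world_size` are hypothetical names, and whether TorchDrug's `get_rank` and `get_world_size` consult these variables is an assumption.

```python
import os
from typing import Mapping, Optional


def env_rank(environ: Optional[Mapping[str, str]] = None) -> int:
    """Rank from the RANK variable set by launchers such as torchrun; 0 if unset."""
    environ = os.environ if environ is None else environ
    return int(environ.get("RANK", 0))


def env_world_size(environ: Optional[Mapping[str, str]] = None) -> int:
    """World size from the WORLD_SIZE variable; 1 if unset (single process)."""
    environ = os.environ if environ is None else environ
    return int(environ.get("WORLD_SIZE", 1))
```

A typical use is restricting logging or checkpointing to one worker, e.g. `if env_rank() == 0: save(...)`.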

reduce(obj, op='sum', dst=None)[source]

Reduce any nested container of tensors.

Parameters
  • obj (Object) – any container object. Can be nested list, tuple or dict.

  • op (str, optional) – element-wise reduction operator. Available operators are sum, mean, min, max, product.

  • dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.

Example:

>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> x = torch.rand(5)
>>> obj = {"polynomial": x ** rank}
>>> obj = comm.reduce(obj)
>>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1)
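To make the `op` argument concrete, a single-process simulation of what every worker would receive when `dst` is not set: the per-worker containers are merged leaf by leaf with the chosen operator. `simulate_reduce` is hypothetical, plain numbers stand in for tensors, and the real function exchanges data between processes.

```python
import math
from typing import Any, List

# Reduction operators named as in the op parameter above.
OPS = {
    "sum": sum,
    "mean": lambda xs: sum(xs) / len(xs),
    "min": min,
    "max": max,
    "product": math.prod,
}


def simulate_reduce(per_worker: List[Any], op: str = "sum") -> Any:
    """Reduce same-structured nested containers from all workers, leaf by leaf.

    Hypothetical single-process simulation; index i of per_worker plays rank i.
    """
    first = per_worker[0]
    if isinstance(first, dict):
        return {key: simulate_reduce([w[key] for w in per_worker], op) for key in first}
    if isinstance(first, (list, tuple)):
        return type(first)(simulate_reduce([w[i] for w in per_worker], op)
                           for i in range(len(first)))
    return OPS[op](per_worker)
```

Mirroring the example above with 4 workers and `x = 2.0`, the default `sum` gives 1 + 2 + 4 + 8 = 15.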

stack(obj, dst=None)[source]

Stack any nested container of tensors. The new dimension will be added at the 0-th axis.

Parameters
  • obj (Object) – any container object. Can be nested list, tuple or dict.

  • dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.

Example:

>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> x = torch.rand(5)
>>> obj = {"exponent": x ** rank}
>>> obj = comm.stack(obj)
>>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3])
>>> assert torch.allclose(obj["exponent"], truth)

cat(obj, dst=None)[source]

Concatenate any nested container of tensors along the 0-th axis.

Parameters
  • obj (Object) – any container object. Can be nested list, tuple or dict.

  • dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.

Example:

>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> rng = torch.arange(10)
>>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]}
>>> obj = comm.cat(obj)
>>> assert torch.allclose(obj["range"], rng)

File Processing

download(url, path, save_file=None, md5=None)[source]

Download a file from the specified URL. The download is skipped if a file matching the given MD5 already exists.

Parameters
  • url (str) – URL to download

  • path (str) – path to store the downloaded file

  • save_file (str, optional) – name of save file. If not specified, infer the file name from the URL.

  • md5 (str, optional) – MD5 of the file
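The skip-if-cached rule can be sketched with the standard library. `download_if_needed` and `file_md5` are hypothetical names; `fetch` is injectable (defaulting to `urllib.request.urlretrieve`) so the logic can be exercised offline. Re-downloading when `md5` is `None`, and verifying the checksum after downloading, are assumptions of this sketch.

```python
import hashlib
import os
from typing import Callable, Optional
from urllib.parse import urlparse
from urllib.request import urlretrieve


def file_md5(file_name: str, chunk_size: int = 65536) -> str:
    """MD5 hex digest of a file, read in chunks."""
    md5 = hashlib.md5()
    with open(file_name, "rb") as fin:
        for chunk in iter(lambda: fin.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def download_if_needed(url: str, path: str, save_file: Optional[str] = None,
                       md5: Optional[str] = None,
                       fetch: Callable[[str, str], object] = urlretrieve) -> str:
    """Download url into path unless a file with the expected MD5 already exists.

    Hypothetical sketch; without an md5 it always downloads (an assumption).
    """
    if save_file is None:
        # Infer the file name from the last component of the URL path.
        save_file = os.path.basename(urlparse(url).path)
    target = os.path.join(path, save_file)
    if md5 is not None and os.path.exists(target) and file_md5(target) == md5:
        return target  # a valid copy is already on disk: skip the download
    os.makedirs(path, exist_ok=True)
    fetch(url, target)
    if md5 is not None and file_md5(target) != md5:
        raise ValueError(f"MD5 mismatch for {target}")
    return target
```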

extract(zip_file, member=None)[source]

Extract files from an archive. Currently, zip, gz, tar.gz and tar files are supported.

Parameters
  • zip_file (str) – path of the archive to extract

  • member (str, optional) – extract a specific member from the archive. If not specified, extract all members.
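Supporting several formats amounts to dispatching on the file extension. A standard-library sketch; `extract_archive` is hypothetical, and extracting next to the archive and returning the output location are assumptions. Note that `.tar.gz` must be tested before the bare `.gz` case, which holds a single compressed file rather than an archive.

```python
import gzip
import os
import shutil
import tarfile
import zipfile
from typing import Optional


def extract_archive(zip_file: str, member: Optional[str] = None) -> str:
    """Extract an archive into its own directory; return the directory or extracted file.

    Hypothetical sketch supporting .zip, .tar.gz/.tgz, .tar and a bare .gz.
    """
    dest = os.path.dirname(zip_file) or "."
    if zip_file.endswith(".zip"):
        with zipfile.ZipFile(zip_file) as zf:
            if member is None:
                zf.extractall(dest)
            else:
                zf.extract(member, dest)
    elif zip_file.endswith((".tar.gz", ".tgz", ".tar")):
        mode = "r" if zip_file.endswith(".tar") else "r:gz"
        with tarfile.open(zip_file, mode) as tf:
            if member is None:
                tf.extractall(dest)
            else:
                tf.extract(member, dest)
    elif zip_file.endswith(".gz"):
        # A bare .gz holds one file: decompress it to the name without ".gz".
        out = zip_file[:-3]
        with gzip.open(zip_file, "rb") as fin, open(out, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        return out
    else:
        raise ValueError(f"unsupported archive type: {zip_file}")
    return dest if member is None else os.path.join(dest, member)
```

Extracting untrusted tar or zip files this way can write outside `dest` through crafted member paths, so only use it on archives from trusted sources.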

compute_md5(file_name, chunk_size=65536)[source]

Compute MD5 of the file.

Parameters
  • file_name (str) – file name

  • chunk_size (int, optional) – chunk size for reading large files
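Reading in fixed-size chunks keeps memory use constant regardless of file size, which is what the `chunk_size` parameter controls. A sketch with `hashlib`; `md5_of_file` is a hypothetical name.

```python
import hashlib


def md5_of_file(file_name: str, chunk_size: int = 65536) -> str:
    """Hex MD5 digest of a file, computed incrementally chunk by chunk."""
    md5 = hashlib.md5()
    with open(file_name, "rb") as fin:
        # iter() with a sentinel stops when read() returns b"" at end of file.
        for chunk in iter(lambda: fin.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()
```

The digest does not depend on the chunk size: feeding the hash piece by piece gives the same result as hashing the whole content at once.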

get_line_count(file_name, chunk_size=8388608)[source]

Get the number of lines in a file.

Parameters
  • file_name (str) – file name

  • chunk_size (int, optional) – chunk size for reading large files
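Line counting can use the same chunked reading and count newline bytes without decoding the text. A sketch; `count_lines` is hypothetical, and treating a final line without a trailing newline as uncounted is an assumption of this sketch.

```python
def count_lines(file_name: str, chunk_size: int = 8388608) -> int:
    """Count newline bytes by scanning the file in binary chunks."""
    count = 0
    with open(file_name, "rb") as fin:
        for chunk in iter(lambda: fin.read(chunk_size), b""):
            # A newline can never be split across chunks, so per-chunk counts add up.
            count += chunk.count(b"\n")
    return count
```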

Commandline I/O

input_choice(prompt, choice=('y', 'n'))[source]

Print a prompt on the command line and wait for a choice.

Parameters
  • prompt (str) – prompt string

  • choice (tuple of str, optional) – candidate choices
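Waiting for a choice usually means re-prompting until the answer is one of the candidates. A sketch; `ask_choice` is hypothetical, exact (case-sensitive) matching is an assumption, and the `read` parameter exists only so the loop can be driven without a terminal.

```python
from typing import Callable, Sequence


def ask_choice(prompt: str, choice: Sequence[str] = ("y", "n"),
               read: Callable[[str], str] = input) -> str:
    """Repeat the prompt until the user types one of the candidate choices."""
    hint = "%s (%s) " % (prompt, "/".join(choice))
    while True:
        answer = read(hint).strip()
        if answer in choice:
            return answer
        print("Please answer one of: %s" % ", ".join(choice))
```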

literal_eval(string)[source]

Evaluate an expression into a Python literal structure.
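The standard library's `ast.literal_eval` safely parses strings such as `"[1, 2]"` or `"0.5"` into Python values without executing arbitrary code. A sketch built on it; `parse_literal` is hypothetical, and falling back to the raw string for input that is not a literal is an assumption, handy for values like `gloo` read from a config or command line.

```python
import ast
from typing import Any


def parse_literal(string: str) -> Any:
    """Parse a string as a Python literal; return it unchanged if it is not one."""
    try:
        return ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # Bare words, calls and other non-literals are kept as plain strings.
        return string
```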

no_rdkit_log()[source]

Context manager to suppress all RDKit logging.

capture_rdkit_log()[source]

Context manager to capture all RDKit logging.

Example:

>>> with utils.capture_rdkit_log() as log:
...     ...
>>> print(log.content)

long_array(array, truncation=10, display=3)[source]

Format an array as a string.

Parameters
  • array (array_like) – array-like data

  • truncation (int, optional) – truncate array if its length exceeds this threshold

  • display (int, optional) – number of elements to display at the beginning and the end in truncated mode
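The interplay of `truncation` and `display` can be sketched as follows: short arrays are printed whole, longer ones keep only `display` elements at each end. `format_long_array` is hypothetical and the exact bracket-and-ellipsis output format is an assumption.

```python
from typing import Sequence


def format_long_array(array: Sequence, truncation: int = 10, display: int = 3) -> str:
    """Show the whole array if short, else the first and last `display` elements."""
    if len(array) <= truncation:
        return "[%s]" % ", ".join(str(x) for x in array)
    head = ", ".join(str(x) for x in array[:display])
    tail = ", ".join(str(x) for x in array[-display:])
    return "[%s, ..., %s]" % (head, tail)
```

For example, `format_long_array(list(range(20)))` gives `[0, 1, 2, ..., 17, 18, 19]`.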

time(seconds)[source]

Format time as a string.

Parameters
  • seconds (float) – time in seconds
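A common way to format a duration is to switch units at fixed thresholds. A sketch; `format_time`, the unit names and the two-decimal precision are assumptions, not necessarily the output of the function above.

```python
def format_time(seconds: float) -> str:
    """Format a duration in secs, mins, hours or days, switching at 60 s, 1 h and 1 day."""
    if seconds < 60:
        return "%.2f secs" % seconds
    if seconds < 3600:
        return "%.2f mins" % (seconds / 60)
    if seconds < 86400:
        return "%.2f hours" % (seconds / 3600)
    return "%.2f days" % (seconds / 86400)
```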

Decorator

cached_property(func)[source]

Cache the property once computed.
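One way to cache a property is a non-data descriptor that stores the computed value in the instance `__dict__`, so later lookups find the stored value and never reach the descriptor again. A sketch; `lazy_property` is a hypothetical name. Python 3.8+ ships the same idea as `functools.cached_property`.

```python
from typing import Any, Callable


class lazy_property:
    """Compute the value on first access, then store it on the instance."""

    def __init__(self, func: Callable[[Any], Any]):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj: Any, objtype: Any = None) -> Any:
        if obj is None:
            return self  # accessed on the class rather than an instance
        value = self.func(obj)
        # No __set__ is defined, so the instance attribute now shadows this
        # descriptor and __get__ is not called again for this object.
        obj.__dict__[self.func.__name__] = value
        return value
```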

Helper functions

copy_args(obj, args=None, ignore=None)[source]

Copy argument documentation from another function to fill in the documentation of **kwargs in this function.

It should be applied as a decorator.

Parameters
  • obj (object) – object to copy document from

  • args (tuple of str, optional) – arguments to copy. By default, it copies all argument documentation from obj, except for arguments already documented in the current function.

  • ignore (tuple of str, optional) – arguments to ignore
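The copying rule (take `obj`'s parameter docs, skip ones already documented, honor `args` and `ignore`) can be sketched as a decorator factory. `copy_args_sketch` is hypothetical and assumes Google-style docstrings with a `Parameters:` section as the last section and one line per argument; TorchDrug's own parsing may differ.

```python
import re
from typing import Callable, Dict, Optional, Sequence

# Matches lines such as "    scale (float): multiplicative factor".
_PARAM = re.compile(r"^\s*(\*{0,2}\w+)\s*\(.*?\):")


def _param_lines(doc: Optional[str]) -> Dict[str, str]:
    """Map parameter name -> its documentation line from a 'Parameters:' section."""
    params, in_section = {}, False
    for line in (doc or "").splitlines():
        if line.strip() == "Parameters:":
            in_section = True
        elif in_section:
            match = _PARAM.match(line)
            if match:
                params[match.group(1)] = line.rstrip()
            elif not line.strip():
                break  # a blank line ends the section in this simplified format
    return params


def copy_args_sketch(obj: Callable, args: Optional[Sequence[str]] = None,
                     ignore: Optional[Sequence[str]] = None) -> Callable:
    """Append parameter docs from obj to the decorated function's docstring."""
    def decorator(func: Callable) -> Callable:
        source, existing = _param_lines(obj.__doc__), _param_lines(func.__doc__)
        names = list(args) if args is not None else [n for n in source if n not in existing]
        names = [n for n in names if n in source and n not in (ignore or ())]
        if names:
            extra = "\n".join(source[n] for n in names)
            func.__doc__ = (func.__doc__ or "").rstrip() + "\n" + extra + "\n"
        return func
    return decorator
```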