torchdrug.utils#
Visualization#
- reaction(reactants, products, save_file=None, figure_size=(3, 3), atom_map=False)[source]#
Visualize a chemical reaction.
- Parameters
reactants (list of Molecule) – list of reactants
products (list of Molecule) – list of products
save_file (str, optional) – png file to save visualization. If not provided, show the figure in a window.
figure_size (tuple of int, optional) – width and height of the figure
atom_map (bool, optional) – visualize atom mapping or not
- highlight(molecule, atoms=None, bonds=None, atom_colors=None, bond_colors=None, save_file=None, figure_size=(3, 3), atom_map=False)[source]#
Visualize a molecule with highlighted atoms or bonds.
- Parameters
molecule (Molecule) – molecule to visualize
atoms (list of int) – indexes of atoms to highlight
bonds (list of int) – indexes of bonds to highlight
atom_colors (tuple or dict) – highlight color for atoms. Can be a tuple of 3 floats between 0 and 1, or a dict that maps each index to a different color.
bond_colors (tuple or dict) – highlight color for bonds. Can be a tuple of 3 floats between 0 and 1, or a dict that maps each index to a different color.
save_file (str, optional) – png file to save visualization. If not provided, show the figure in a window.
figure_size (tuple of int, optional) – width and height of the figure
atom_map (bool, optional) – visualize atom mapping or not
- echarts(graph, title=None, node_colors=None, edge_colors=None, node_labels=None, relation_labels=None, node_types=None, type_labels=None, dynamic_size=False, dynamic_width=False, save_file=None)[source]#
Visualize a graph in ECharts.
- Parameters
graph (Graph) – graph to visualize
title (str, optional) – title of the graph
node_colors (dict, optional) – specify colors for some nodes. Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.
edge_colors (dict, optional) – specify colors for some edges. Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.
node_labels (list of str, optional) – labels for each node
relation_labels (list of str, optional) – labels for each relation
node_types (list of int, optional) – type for each node
type_labels (list of str, optional) – labels for each node type
dynamic_size (bool, optional) – if true, set the size of nodes based on the logarithm of degrees
dynamic_width (bool, optional) – if true, set the width of edges based on the edge weights
save_file (str, optional) – html file to save visualization, accompanied by a json file
Auxiliary Torch Functions#
- load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs)[source]#
Load a PyTorch C++ extension just-in-time (JIT). Automatically decide the compilation flags if not specified.
This function performs lazy evaluation and is multi-process-safe.
See torch.utils.cpp_extension.load for more details.
- cat(objs, *args, **kwargs)[source]#
Concatenate a list of nested containers with the same structure.
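To make "nested containers with the same structure" concrete, here is a minimal pure-Python sketch of the idea. It is not torchdrug's implementation: `cat_nested` is a hypothetical name, plain lists stand in for tensors as the leaves, and only dicts and tuples are treated as containers.

```python
def cat_nested(objs):
    """Concatenate containers that share one structure (illustrative only)."""
    first = objs[0]
    if isinstance(first, dict):
        # Recurse key by key; every container is assumed to have the same keys.
        return {key: cat_nested([obj[key] for obj in objs]) for key in first}
    if isinstance(first, tuple):
        # Recurse position by position.
        return tuple(cat_nested(list(items)) for items in zip(*objs))
    # Leaf: a flat list standing in for a tensor; join the leaves in order.
    return [value for obj in objs for value in obj]


batches = [
    {"x": [1, 2], "pair": ([0.1], ["a"])},
    {"x": [3], "pair": ([0.2, 0.3], ["b"])},
]
merged = cat_nested(batches)
print(merged)
```

Each leaf of the result holds the leaves of the inputs joined in order, while the dict and tuple layout is preserved.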
- sparse_coo_tensor(indices, values, size)[source]#
Construct a sparse COO tensor without index check. Much faster than torch.sparse_coo_tensor.
- Parameters
indices (Tensor) – 2D indices of shape (2, n)
values (Tensor) – values of shape (n,)
size (list) – size of the tensor
Distributed Communication#
- init_process_group(backend, init_method=None, **kwargs)[source]#
Initialize CPU and/or GPU process groups.
- Parameters
backend (str) – Communication backend. Use nccl for GPUs and gloo for CPUs.
init_method (str, optional) – URL specifying how to initialize the process group
- get_group(device)[source]#
Get the process group corresponding to the given device.
- Parameters
device (torch.device) – query device
- get_rank()[source]#
Get the rank of this process in distributed processes.
Return 0 for single process case.
- get_world_size()[source]#
Get the total number of distributed processes.
Return 1 for single process case.
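The single-process defaults of 0 and 1 can be illustrated with a small sketch. It assumes the rank and world size come from the `RANK` and `WORLD_SIZE` environment variables, which PyTorch launchers such as torchrun set. Whether torchdrug itself reads these variables is an assumption, and both function names below are hypothetical.

```python
import os


def rank_from_env(environ=os.environ):
    # Fall back to rank 0 when no launcher has set RANK (single process).
    return int(environ.get("RANK", 0))


def world_size_from_env(environ=os.environ):
    # Fall back to a world size of 1 when WORLD_SIZE is not set.
    return int(environ.get("WORLD_SIZE", 1))


single = (rank_from_env({}), world_size_from_env({}))
launched = (rank_from_env({"RANK": "3", "WORLD_SIZE": "4"}),
            world_size_from_env({"RANK": "3", "WORLD_SIZE": "4"}))
print(single, launched)
```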
- reduce(obj, op='sum', dst=None)[source]#
Reduce any nested container of tensors.
- Parameters
obj (Object) – any container object. Can be nested list, tuple or dict.
op (str, optional) – element-wise reduction operator. Available operators are sum, mean, min, max, product.
dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.
Example:
>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> x = torch.rand(5)
>>> obj = {"polynomial": x ** rank}
>>> obj = comm.reduce(obj)
>>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1)
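The arithmetic in the example can be checked without any distributed setup. The sketch below simulates the four workers in a single process, assuming the same `x` on every worker. It is not torchdrug's implementation: `reduce_nested` is a hypothetical helper, and plain floats stand in for tensor elements.

```python
import math

# Element-wise operators named in the parameter list above.
OPS = {
    "sum": sum,
    "mean": lambda values: sum(values) / len(values),
    "min": min,
    "max": max,
    "product": math.prod,
}


def reduce_nested(per_worker, op="sum"):
    """Reduce one nested container per worker into a single container."""
    first = per_worker[0]
    if isinstance(first, dict):
        return {k: reduce_nested([o[k] for o in per_worker], op) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(reduce_nested(list(v), op) for v in zip(*per_worker))
    # Leaf: one scalar per worker.
    return OPS[op](per_worker)


x = [0.5, 2.0]
# Worker `rank` holds x ** rank, as in the example above.
per_worker = [{"polynomial": [v ** rank for v in x]} for rank in range(4)]
result = reduce_nested(per_worker)
expected = [v ** 3 + v ** 2 + v + 1 for v in x]
print(result["polynomial"], expected)
```

With `op="sum"`, the element for `x = 2.0` is 1 + 2 + 4 + 8 = 15, which matches `x ** 3 + x ** 2 + x + 1`.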
- stack(obj, dst=None)[source]#
Stack any nested container of tensors. The new dimension will be added at the 0-th axis.
- Parameters
obj (Object) – any container object. Can be nested list, tuple or dict.
dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.
Example:
>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> x = torch.rand(5)
>>> obj = {"exponent": x ** rank}
>>> obj = comm.stack(obj)
>>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3])
>>> assert torch.allclose(obj["exponent"], truth)
- cat(obj, dst=None)[source]#
Concatenate any nested container of tensors along the 0-th axis.
- Parameters
obj (Object) – any container object. Can be nested list, tuple or dict.
dst (int, optional) – rank of destination worker. If not specified, broadcast the result to all workers.
Example:
>>> # assume 4 workers
>>> rank = comm.get_rank()
>>> rng = torch.arange(10)
>>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]}
>>> obj = comm.cat(obj)
>>> assert torch.allclose(obj["range"], rng)
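The slice bounds in the example are consecutive triangular numbers, so the four workers hold chunks of length 1, 2, 3 and 4 that tile `rng` exactly. A plain-Python check of that arithmetic, with lists standing in for tensors and the workers simulated in a loop:

```python
world_size = 4
rng = list(range(10))

# Worker r holds rng[r(r+1)/2 : (r+1)(r+2)/2], as in the example above.
chunks = [
    rng[r * (r + 1) // 2:(r + 1) * (r + 2) // 2]
    for r in range(world_size)
]
lengths = [len(chunk) for chunk in chunks]

# Concatenating the chunks in rank order recovers the full range.
gathered = [value for chunk in chunks for value in chunk]
print(lengths, gathered)
```

The chunks have different lengths, so the example also implies that the per-worker tensors need not have the same size along the 0-th axis.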
File Processing#
- download(url, path, save_file=None, md5=None)[source]#
Download a file from the specified URL. Skip the download if a file matching the given MD5 already exists.
- Parameters
url (str) – URL to download
path (str) – path to store the downloaded file
save_file (str, optional) – name of save file. If not specified, infer the file name from the URL.
md5 (str, optional) – MD5 of the file
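The MD5-based skip can be sketched with the standard library alone. This is a minimal illustration of the check, not torchdrug's code: `md5_of` and `can_skip_download` are hypothetical names, and treating a missing `md5` as "never skip" is a choice made for this sketch.

```python
import hashlib
import os
import tempfile


def md5_of(path, chunk_size=65536):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for block in iter(lambda: handle.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def can_skip_download(path, md5):
    # Skip only when the file exists and its digest matches the expected one.
    # With md5=None nothing matches, so this sketch always downloads.
    return os.path.exists(path) and md5_of(path) == md5


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "data.bin")
    with open(target, "wb") as handle:
        handle.write(b"hello")
    good = hashlib.md5(b"hello").hexdigest()
    skip_when_matching = can_skip_download(target, good)
    skip_when_stale = can_skip_download(target, "0" * 32)
    skip_when_missing = can_skip_download(os.path.join(tmp, "absent"), good)

print(skip_when_matching, skip_when_stale, skip_when_missing)
```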
- extract(zip_file, member=None)[source]#
Extract files from a zip file. Currently, the zip, gz, tar.gz and tar file types are supported.
- Parameters
zip_file (str) – file name
member (str, optional) – extract specific member from the zip file. If not specified, extract all members.
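One way to handle the four listed file types is to dispatch on the file extension. The sketch below does this with standard-library modules only. `extract_archive` and its `out_dir` argument are hypothetical and not part of torchdrug; single-member extraction is omitted for brevity.

```python
import gzip
import os
import shutil
import tarfile
import tempfile
import zipfile


def extract_archive(archive, out_dir):
    """Extract all members of `archive` into `out_dir`; return member names."""
    os.makedirs(out_dir, exist_ok=True)
    if archive.endswith(".zip"):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(out_dir)
            return sorted(zf.namelist())
    # Check tar.gz before gz, because both names end with ".gz".
    if archive.endswith((".tar.gz", ".tar")):
        with tarfile.open(archive) as tf:  # default mode detects compression
            names = tf.getnames()
            tf.extractall(out_dir)
            return sorted(names)
    if archive.endswith(".gz"):
        # A bare .gz file holds a single compressed file.
        name = os.path.basename(archive)[:-len(".gz")]
        with gzip.open(archive, "rb") as src:
            with open(os.path.join(out_dir, name), "wb") as dst:
                shutil.copyfileobj(src, dst)
        return [name]
    raise ValueError("unsupported file type: %s" % archive)


with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("a.txt", "hello")
    members = extract_archive(archive, os.path.join(tmp, "out"))
    with open(os.path.join(tmp, "out", "a.txt")) as handle:
        content = handle.read()

print(members, content)
```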
Commandline I/O#
- input_choice(prompt, choice=('y', 'n'))[source]#
Print a prompt on the command line and wait for a choice.
- Parameters
prompt (str) – prompt string
choice (tuple of str, optional) – candidate choices
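A prompt of this kind usually repeats until the answer is one of the candidates. Here is a minimal sketch of that loop, assuming exact string matching. `ask_choice` and its injectable `read` argument are hypothetical and exist only so the loop can be exercised without a terminal.

```python
def ask_choice(prompt, choice=("y", "n"), read=input):
    """Keep asking until the answer is one of the candidate choices."""
    options = "/".join(choice)
    while True:
        answer = read("%s (%s) " % (prompt, options)).strip()
        if answer in choice:
            return answer


# Scripted answers stand in for keyboard input: two invalid, then a valid one.
answers = iter(["maybe", "", "n"])
result = ask_choice("Overwrite existing file?", read=lambda _: next(answers))
print(result)
```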
- capture_rdkit_log()[source]#
Context manager to capture all RDKit log messages.
Example:
>>> with utils.capture_rdkit_log() as log:
>>>     ...
>>> print(log.content)
- long_array(array, truncation=10, display=3)[source]#
Format an array as a string.
- Parameters
array (array_like) – array-like data
truncation (int, optional) – truncate array if its length exceeds this threshold
display (int, optional) – number of elements to display at the beginning and the end in truncated mode
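The interplay of `truncation` and `display` can be shown in a short sketch. The output format here is an assumption made for illustration, and `format_long` is a hypothetical name, so the library's exact string may differ.

```python
def format_long(array, truncation=10, display=3):
    """Format a sequence, eliding the middle when it is longer than `truncation`."""
    items = list(array)
    if len(items) <= truncation:
        return str(items)
    # Keep `display` elements from each end and elide the rest.
    head = ", ".join(str(value) for value in items[:display])
    tail = ", ".join(str(value) for value in items[-display:])
    return "[%s, ..., %s]" % (head, tail)


short = format_long([1, 2, 3])
long_ = format_long(range(20))
print(short)
print(long_)
```

A sequence of exactly `truncation` elements is still printed in full, since the description above says truncation applies only when the length exceeds the threshold.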
Decorator#
Helper functions#
- copy_args(obj, args=None, ignore=None)[source]#
Copy argument documentation from another function to fill in the documentation of **kwargs in this function.
This function should be applied as a decorator.
- Parameters
obj (object) – object to copy document from
args (tuple of str, optional) – arguments to copy. By default, it copies all argument documentation from obj, except those that already exist in the current function.
ignore (tuple of str, optional) – arguments to ignore