import io
import os
import json
import jinja2
from PIL import Image
from rdkit.Chem import AllChem, Draw
# Directory named "template" located next to this module; the ECharts HTML
# template (``echarts.html``) is loaded from here.
path = os.path.join(os.path.dirname(__file__), "template")
def reaction(reactants, products, save_file=None, figure_size=(3, 3), atom_map=False):
    """
    Draw a chemical reaction from its reactants and products.

    Parameters:
        reactants (list of Molecule): molecules on the left-hand side of the reaction
        products (list of Molecule): molecules on the right-hand side of the reaction
        save_file (str, optional): ``png`` file to save visualization.
            If not provided, show the figure in window.
        figure_size (tuple of int, optional): width and height of the figure
        atom_map (bool, optional): visualize atom mapping or not
    """
    rxn = AllChem.ChemicalReaction()

    # Reactants are registered first, then products; each side has its own adder.
    sides = (
        (reactants, rxn.AddReactantTemplate),
        (products, rxn.AddProductTemplate),
    )
    for members, add_template in sides:
        for member in members:
            rd_mol = member.to_molecule()
            if not atom_map:
                # A map number of 0 is written on every atom so no mapping is drawn.
                for rd_atom in rd_mol.GetAtoms():
                    rd_atom.SetAtomMapNum(0)
            add_template(rd_mol)

    # Each unit of ``figure_size`` corresponds to 100 pixels.
    pixel_size = [100 * extent for extent in figure_size]
    picture = Draw.ReactionToImage(rxn, pixel_size)

    if save_file is not None:
        picture.save(save_file)
        return
    picture.show()
def _expand_highlight_colors(indexes, colors):
    """Turn a single color into a per-index dict; pass dicts and ``None`` through unchanged."""
    if colors is None or isinstance(colors, dict):
        # ``None`` lets RDKit fall back to its default highlight color.
        return colors
    return dict.fromkeys(indexes, colors)


def highlight(molecule, atoms=None, bonds=None, atom_colors=None, bond_colors=None, save_file=None, figure_size=(3, 3),
              atom_map=False):
    """
    Visualize a molecule with highlighted atoms or bonds.

    Parameters:
        molecule (Molecule): molecule to visualize
        atoms (list of int, optional): indexes of atoms to highlight.
            If not provided, no atom is highlighted.
        bonds (list of int, optional): indexes of bonds to highlight.
            If not provided, no bond is highlighted.
        atom_colors (tuple or dict, optional): highlight color for atoms.
            Can be a tuple of 3 float between 0 and 1, or a dict that maps each index to a different color.
            If not provided, the default highlight color of the drawer is used.
        bond_colors (tuple or dict, optional): highlight color for bonds.
            Can be a tuple of 3 float between 0 and 1, or a dict that maps each index to a different color.
            If not provided, the default highlight color of the drawer is used.
        save_file (str, optional): ``png`` file to save visualization.
            If not provided, show the figure in window.
        figure_size (tuple of int, optional): width and height of the figure
        atom_map (bool, optional): visualize atom mapping or not
    """
    # The defaults are None; treat them as "nothing to highlight" rather than
    # failing inside ``dict.fromkeys(None, ...)``.
    if atoms is None:
        atoms = []
    if bonds is None:
        bonds = []
    atom_colors = _expand_highlight_colors(atoms, atom_colors)
    bond_colors = _expand_highlight_colors(bonds, bond_colors)

    mol = molecule.to_molecule()
    if not atom_map:
        # A map number of 0 is written on every atom so no mapping is drawn.
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)

    # Each unit of ``figure_size`` corresponds to 100 pixels.
    size = [100 * s for s in figure_size]
    canvas = Draw.rdMolDraw2D.MolDraw2DCairo(*size)
    Draw.rdMolDraw2D.PrepareAndDrawMolecule(canvas, mol, highlightAtoms=atoms, highlightBonds=bonds,
                                            highlightAtomColors=atom_colors, highlightBondColors=bond_colors)

    if save_file is None:
        # Decode the drawing text from memory with PIL to show it in a window.
        stream = io.BytesIO(canvas.GetDrawingText())
        img = Image.open(stream)
        img.show()
    else:
        canvas.WriteDrawingText(save_file)
def _to_css_color(color):
    """Convert an RGB tuple/list to an ``rgb(r, g, b)`` string; other values (e.g. hex codes) pass through."""
    if isinstance(color, (tuple, list)):
        return "rgb%s" % (tuple(color),)
    return color


def echarts(graph, title=None, node_colors=None, edge_colors=None, node_labels=None, relation_labels=None,
            node_types=None, type_labels=None, dynamic_size=False, dynamic_width=False, save_file=None):
    """
    Visualize a graph in ECharts.

    Two files are written: the ``html`` page rendered from the ``echarts.html`` template,
    and a ``json`` data file with the same base name that the page refers to.

    Parameters:
        graph (Graph): graph to visualize
        title (str, optional): title of the graph
        node_colors (dict, optional): specify colors for some nodes.
            Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.
        edge_colors (dict, optional): specify colors for some edges.
            Each color is either a tuple of 3 integers between 0 and 255, or a hex color code.
        node_labels (list of str, optional): labels for each node
        relation_labels (list of str, optional): labels for each relation
        node_types (list of int, optional): type for each node
        type_labels (list of str, optional): labels for each node type
        dynamic_size (bool, optional): if true, set the size of nodes based on the logarithm of degrees
        dynamic_width (bool, optional): if true, set the width of edges based on the edge weights
        save_file (str): ``html`` file to save visualization, accompanied by a ``json`` file.
            Although it defaults to ``None`` in the signature, a path must be given.

    Raises:
        TypeError: if ``save_file`` is not provided
    """
    if save_file is None:
        # Fail before doing any work instead of deep inside ``os.path.splitext``.
        raise TypeError("`save_file` is required: expect the path of an html file, but got None")

    if dynamic_size:
        # log(in-degree + out-degree + 2), rescaled so that the mean size is 10
        symbol_size = (graph.degree_in + graph.degree_out + 2).log()
        symbol_size = symbol_size / symbol_size.mean() * 10
        symbol_size = symbol_size.tolist()
    else:
        symbol_size = [10] * graph.num_node

    node_colors = node_colors or {}
    nodes = []
    for i in range(graph.num_node):
        node = {
            "id": i,
            "symbolSize": symbol_size[i],
        }
        if i in node_colors:
            node["itemStyle"] = {"color": _to_css_color(node_colors[i])}
        if node_labels:
            node["name"] = node_labels[i]
        if node_types:
            node["category"] = node_types[i]
        nodes.append(node)

    if dynamic_width:
        # edge weight rescaled so that the mean width is 3
        width = graph.edge_weight / graph.edge_weight.mean() * 3
        width = width.tolist()
    else:
        width = [3] * graph.num_edge

    # With relations the edge list unpacks into three columns, otherwise two.
    if graph.num_relation:
        node_in, node_out, relation = graph.edge_list.t().tolist()
    else:
        node_in, node_out = graph.edge_list.t().tolist()
        relation = None

    edge_colors = edge_colors or {}
    edges = []
    for i in range(graph.num_edge):
        line_style = {"width": width[i]}
        if i in edge_colors:
            # Add the color to the existing style; replacing the dict would drop the width.
            line_style["color"] = _to_css_color(edge_colors[i])
        edge = {
            "source": node_in[i],
            "target": node_out[i],
            "lineStyle": line_style,
        }
        if relation_labels:
            edge["value"] = relation_labels[relation[i]]
        edges.append(edge)

    json_file = os.path.splitext(save_file)[0] + ".json"
    data = {
        "title": title,
        "nodes": nodes,
        "edges": edges,
    }
    if type_labels:
        data["categories"] = [{"name": label} for label in type_labels]
    variables = {
        # the page refers to the data file by base name, i.e. relative to the html file
        "data_file": os.path.basename(json_file),
        "show_label": "true" if node_labels else "false",
    }

    # Render first, so a missing or broken template does not leave a truncated html file behind.
    with open(os.path.join(path, "echarts.html"), "r") as fin:
        template = jinja2.Template(fin.read())
    instance = template.render(variables)
    with open(save_file, "w") as fout:
        fout.write(instance)
    with open(json_file, "w") as fout:
        json.dump(data, fout, sort_keys=True, indent=4)