Source code for torchdrug.tasks.retrosynthesis

import inspect
from collections import deque

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data
from torch_scatter import scatter_max, scatter_add

from torchdrug import core, tasks, data, metrics, transforms
from torchdrug.layers import functional
from torchdrug.core import Registry as R
from torchdrug import layers

import logging
logger = logging.getLogger(__name__)


@R.register("tasks.CenterIdentification")
class CenterIdentification(tasks.Task, core.Configurable):
    """
    Reaction center identification task.

    This class is a part of retrosynthesis prediction.

    Parameters:
        model (nn.Module): graph representation model
        feature (str or list of str, optional): additional features for prediction. Available features are
            reaction: type of the reaction
            graph: graph representation of the product
            atom: original atom feature
            bond: original bond feature
        num_mlp_layer (int, optional): number of MLP layers
    """

    _option_members = {"feature"}

    def __init__(self, model, feature=("reaction", "graph", "atom", "bond"), num_mlp_layer=2):
        super(CenterIdentification, self).__init__()
        self.model = model
        self.num_mlp_layer = num_mlp_layer
        self.feature = feature

    def preprocess(self, train_set, valid_set, test_set):
        reaction_types = set()
        bond_types = set()
        for sample in train_set:
            reaction_types.add(sample["reaction"])
            for graph in sample["graph"]:
                bond_types.update(graph.edge_list[:, 2].tolist())
        self.num_reaction = len(reaction_types)
        self.num_relation = len(bond_types)
        node_feature_dim = train_set[0]["graph"][0].node_feature.shape[-1]
        edge_feature_dim = train_set[0]["graph"][0].edge_feature.shape[-1]

        node_dim = self.model.output_dim
        edge_dim = 0
        graph_dim = 0
        for _feature in sorted(self.feature):
            if _feature == "reaction":
                graph_dim += self.num_reaction
            elif _feature == "graph":
                graph_dim += self.model.output_dim
            elif _feature == "atom":
                node_dim += node_feature_dim
            elif _feature == "bond":
                edge_dim += edge_feature_dim
            else:
                raise ValueError("Unknown feature `%s`" % _feature)

        node_dim += graph_dim  # inherit graph features
        edge_dim += node_dim * 2  # inherit node features

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.edge_mlp = layers.MLP(edge_dim, hidden_dims + [1])
        self.node_mlp = layers.MLP(node_dim, hidden_dims + [1])

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)
        target = self.target(batch)
        metric.update(self.evaluate(pred, target))

        target, size = target
        target = functional.variadic_max(target, size)[1]
        loss = functional.variadic_cross_entropy(pred, target, size)
        name = tasks._get_criterion_name("ce")
        metric[name] = loss
        all_loss += loss

        return all_loss, metric

    def _collate(self, edge_data, node_data, graph):
        new_data = torch.zeros(len(edge_data) + len(node_data), *edge_data.shape[1:],
                               dtype=edge_data.dtype, device=edge_data.device)
        num_cum_xs = graph.num_cum_edges + graph.num_cum_nodes
        num_xs = graph.num_edges + graph.num_nodes
        starts = num_cum_xs - num_xs
        ends = starts + graph.num_edges
        index = functional.multi_slice_mask(starts, ends, num_cum_xs[-1])
        new_data[index] = edge_data
        new_data[~index] = node_data
        return new_data

    def target(self, batch):
        reactant, product = batch["graph"]
        graph = product.directed()

        target = self._collate(graph.edge_label, graph.node_label, graph)
        size = graph.num_edges + graph.num_nodes
        return target, size

    def predict(self, batch, all_loss=None, metric=None):
        reactant, product = batch["graph"]
        output = self.model(product, product.node_feature.float(), all_loss, metric)

        graph = product.directed()
        node_feature = [output["node_feature"]]
        edge_feature = []
        graph_feature = []
        for _feature in sorted(self.feature):
            if _feature == "reaction":
                reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32,
                                               device=self.device)
                reaction_feature.scatter_(1, batch["reaction"].unsqueeze(-1), 1)
                graph_feature.append(reaction_feature)
            elif _feature == "graph":
                graph_feature.append(output["graph_feature"])
            elif _feature == "atom":
                node_feature.append(graph.node_feature.float())
            elif _feature == "bond":
                edge_feature.append(graph.edge_feature.float())
            else:
                raise ValueError("Unknown feature `%s`" % _feature)

        graph_feature = torch.cat(graph_feature, dim=-1)
        # inherit graph features
        node_feature.append(graph_feature[graph.node2graph])
        node_feature = torch.cat(node_feature, dim=-1)
        # inherit node features
        edge_feature.append(node_feature[graph.edge_list[:, :2]].flatten(1))
        edge_feature = torch.cat(edge_feature, dim=-1)

        edge_pred = self.edge_mlp(edge_feature).squeeze(-1)
        node_pred = self.node_mlp(node_feature).squeeze(-1)

        pred = self._collate(edge_pred, node_pred, graph)

        return pred

    def evaluate(self, pred, target):
        target, size = target
        metric = {}

        target = functional.variadic_max(target, size)[1]
        accuracy = metrics.variadic_accuracy(pred, target, size).mean()

        name = tasks._get_metric_name("acc")
        metric[name] = accuracy

        return metric

    @torch.no_grad()
    def predict_synthon(self, batch, k=1):
        """
        Predict top-k synthons from target molecules.

        Parameters:
            batch (dict): batch of target molecules
            k (int, optional): return top-k results

        Returns:
            dict: top-k predictions, as a batch dict with keys ``synthon``, ``num_synthon``,
                ``reaction_center``, ``log_likelihood`` and ``reaction``.
        """
        pred = self.predict(batch)
        target, size = self.target(batch)
        logp = functional.variadic_log_softmax(pred, size)

        reactant, product = batch["graph"]
        graph = product.directed()
        with graph.graph():
            graph.product_id = torch.arange(len(graph), device=self.device)

        graph = graph.repeat_interleave(k)
        reaction = batch["reaction"].repeat_interleave(k)
        with graph.graph():
            graph.split_id = torch.arange(k, device=self.device).repeat(len(graph) // k)

        logp, center_topk = functional.variadic_topk(logp, size, k)
        logp = logp.flatten()
        center_topk = center_topk.flatten()

        is_edge = center_topk < graph.num_edges
        node_index = center_topk + graph.num_cum_nodes - graph.num_nodes - graph.num_edges
        edge_index = center_topk + graph.num_cum_edges - graph.num_edges
        center_topk_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
                                         center_topk[:-1]])
        product_id_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
                                        graph.product_id[:-1]])
        is_duplicate = (center_topk == center_topk_shifted) & (graph.product_id == product_id_shifted)
        node_index = node_index[~is_edge]
        edge_index = edge_index[is_edge]
        edge_mask = ~functional.as_mask(edge_index, graph.num_edge)

        reaction_center = torch.zeros(len(graph), 2, dtype=torch.long, device=self.device)
        reaction_center[is_edge] = graph.atom_map[graph.edge_list[edge_index, :2]]
        reaction_center[~is_edge, 0] = graph.atom_map[node_index]

        # remove the edges from products
        graph = graph.edge_mask(edge_mask)
        graph = graph[~is_duplicate]
        reaction_center = reaction_center[~is_duplicate]
        logp = logp[~is_duplicate]
        reaction = reaction[~is_duplicate]
        synthon, num_synthon = graph.connected_components()
        synthon = synthon.undirected()  # (< num_graph * k)

        result = {
            "synthon": synthon,
            "num_synthon": num_synthon,
            "reaction_center": reaction_center,
            "log_likelihood": logp,
            "reaction": reaction,
        }

        return result
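
# --------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library). It assumes a USPTO50k-style dataset
# that yields items with "graph" and "reaction" fields, and an RGCN encoder; the dataset
# path and hyperparameters below are placeholders, and `core.Engine` is expected to call
# `task.preprocess` before training.
def _example_center_identification():
    from torchdrug import datasets, models, core

    dataset = datasets.USPTO50k("~/molecule-datasets/", node_feature="center_identification",
                                kekulize=True)
    train_set, valid_set, test_set = dataset.split()

    model = models.RGCN(input_dim=dataset.node_feature_dim, hidden_dims=[256, 256, 256],
                        num_relation=dataset.num_bond_type, concat_hidden=True)
    task = CenterIdentification(model, feature=("graph", "atom", "bond"))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer, batch_size=32)
    solver.train(num_epoch=1)

    # top-2 reaction centers / synthons for a small batch of products
    loader = data.DataLoader(test_set, batch_size=4)
    batch = next(iter(loader))
    return task.predict_synthon(batch, k=2)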

class RandomBFSOrder(object):

    def __call__(self, item):
        assert hasattr(item["graph"][0], "reaction_center")
        reactant, synthon = item["graph"]

        edge_list = reactant.edge_list[:, :2].tolist()
        neighbor = [[] for _ in range(reactant.num_node)]
        for h, t in edge_list:
            neighbor[h].append(t)
        depth = [-1] * reactant.num_node

        # select a mapped atom as BFS root
        reactant2id = reactant.atom_map
        id2synthon = -torch.ones(synthon.atom_map.max() + 1, dtype=torch.long, device=synthon.device)
        id2synthon[synthon.atom_map] = torch.arange(synthon.num_node, device=synthon.device)
        reactant2synthon = id2synthon[reactant2id]

        candidate = (reactant2synthon != -1).nonzero().squeeze(-1)
        i = candidate[torch.randint(len(candidate), (1,))].item()

        queue = deque([i])
        depth[i] = 0
        order = []
        while queue:
            h = queue.popleft()
            order.append(h)
            for t in neighbor[h]:
                if depth[t] == -1:
                    depth[t] = depth[h] + 1
                    queue.append(t)

        reactant = reactant.subgraph(order)

        if reactant.num_edge > 0:
            node_index = reactant.edge_list[:, :2]
            node_large = node_index.max(dim=-1)[0]
            node_small = node_index.min(dim=-1)[0]
            undirected_edge_id = node_large * (node_large + 1) + node_small
            undirected_edge_id = undirected_edge_id * 2 + (node_index[:, 0] > node_index[:, 1])

            # rearrange edges into autoregressive order
            edge_order = undirected_edge_id.argsort()
            reactant = reactant.edge_mask(edge_order)

        assert hasattr(reactant, "reaction_center")

        item = item.copy()
        item["graph"] = (reactant, synthon)

        return item
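
# --------------------------------------------------------------------------------------
# Small self-contained sketch (illustration only) of the undirected-edge-id trick used above:
# an edge (h, t) and its reverse (t, h) map to adjacent ids, so argsort places both directions
# of an edge next to each other and orders edge pairs by their larger node index, i.e. by the
# node that appears later in the BFS order.
def _example_undirected_edge_order():
    edge_list = torch.tensor([[2, 0], [0, 2], [1, 0], [0, 1], [2, 1], [1, 2]])
    node_large = edge_list.max(dim=-1)[0]
    node_small = edge_list.min(dim=-1)[0]
    undirected_edge_id = node_large * (node_large + 1) + node_small
    undirected_edge_id = undirected_edge_id * 2 + (edge_list[:, 0] > edge_list[:, 1])
    # ids: 13, 12, 5, 4, 15, 14 -> sorted edges: (0,1), (1,0), (0,2), (2,0), (1,2), (2,1)
    return edge_list[undirected_edge_id.argsort()]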

@R.register("tasks.SynthonCompletion")
class SynthonCompletion(tasks.Task, core.Configurable):
    """
    Synthon completion task.

    This class is a part of retrosynthesis prediction.

    Parameters:
        model (nn.Module): graph representation model
        feature (str or list of str, optional): additional features for prediction. Available features are
            reaction: type of the reaction
            graph: graph representation of the synthon
            atom: original atom feature
        num_mlp_layer (int, optional): number of MLP layers
    """

    _option_members = {"feature"}

    def __init__(self, model, feature=("reaction", "graph", "atom"), num_mlp_layer=2):
        super(SynthonCompletion, self).__init__()
        self.model = model
        self.num_mlp_layer = num_mlp_layer
        self.feature = feature
        self.input_linear = nn.Linear(2, self.model.input_dim)

    def preprocess(self, train_set, valid_set, test_set):
        reaction_types = set()
        atom_types = set()
        bond_types = set()
        for sample in train_set:
            reaction_types.add(sample["reaction"])
            for graph in sample["graph"]:
                atom_types.update(graph.atom_type.tolist())
                bond_types.update(graph.edge_list[:, 2].tolist())
        # TODO: only for fast debugging, to remove
        # atom_types = torch.tensor([5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 29, 30, 34, 35, 50, 53])
        # bond_types = torch.tensor([0, 1, 2])
        atom_types = torch.tensor(sorted(atom_types))
        atom2id = -torch.ones(atom_types.max() + 1, dtype=torch.long)
        atom2id[atom_types] = torch.arange(len(atom_types))
        self.register_buffer("id2atom", atom_types)
        self.register_buffer("atom2id", atom2id)
        self.num_reaction = len(reaction_types)
        self.num_atom_type = len(atom_types)
        self.num_bond_type = len(bond_types)
        node_feature_dim = train_set[0]["graph"][0].node_feature.shape[-1]

        if isinstance(train_set, torch_data.Subset):
            dataset = train_set.dataset
        else:
            dataset = train_set
        dataset.transform = transforms.Compose([
            dataset.transform,
            RandomBFSOrder(),
        ])
        sig = inspect.signature(data.PackedMolecule.from_molecule)
        keys = set(sig.parameters.keys())
        kwargs = dataset.config_dict()
        feature_kwargs = {}
        for k, v in kwargs.items():
            if k in keys:
                feature_kwargs[k] = v
        self.feature_kwargs = feature_kwargs

        node_dim = self.model.output_dim
        edge_dim = 0
        graph_dim = 0
        for _feature in sorted(self.feature):
            if _feature == "reaction":
                graph_dim += self.num_reaction
            elif _feature == "graph":
                graph_dim += self.model.output_dim
            elif _feature == "atom":
                node_dim += node_feature_dim
            else:
                raise ValueError("Unknown feature `%s`" % _feature)

        self.new_atom_feature = nn.Embedding(self.num_atom_type, node_dim)

        node_dim += graph_dim  # inherit graph features
        edge_dim += node_dim * 2  # inherit node features

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.node_in_mlp = layers.MLP(node_dim, hidden_dims + [1])
        self.node_out_mlp = layers.MLP(edge_dim, hidden_dims + [1])
        self.edge_mlp = layers.MLP(edge_dim, hidden_dims + [1])
        self.bond_mlp = layers.MLP(edge_dim, hidden_dims + [self.num_bond_type])
        self.stop_mlp = layers.MLP(graph_dim, hidden_dims + [1])

    def _update_molecule_feature(self, graphs):
        # This function is very slow
        graphs = graphs.ion_to_molecule()
        mols = graphs.to_molecule(ignore_error=True)
        valid = [mol is not None for mol in mols]
        valid = torch.tensor(valid, device=graphs.device)
        new_graphs = type(graphs).from_molecule(mols, **self.feature_kwargs)

        node_feature = torch.zeros(graphs.num_node, *new_graphs.node_feature.shape[1:],
                                   dtype=new_graphs.node_feature.dtype, device=graphs.device)
        edge_feature = torch.zeros(graphs.num_edge, *new_graphs.edge_feature.shape[1:],
                                   dtype=new_graphs.edge_feature.dtype, device=graphs.device)
        bond_type = torch.zeros_like(graphs.bond_type)
        node_mask = valid[graphs.node2graph]
        edge_mask = valid[graphs.edge2graph]
        node_feature[node_mask] = new_graphs.node_feature.to(device=graphs.device)
        edge_feature[edge_mask] = new_graphs.edge_feature.to(device=graphs.device)
        bond_type[edge_mask] = new_graphs.bond_type.to(device=graphs.device)

        with graphs.node():
            graphs.node_feature = node_feature
        with graphs.edge():
            graphs.edge_feature = edge_feature
            graphs.bond_type = bond_type

        return graphs, valid

    @torch.no_grad()
    def _all_prefix_slice(self, num_xs, lengths=None):
        # extract a bunch of slices that correspond to the following num_repeat * n masks
        # ------ repeat 0 -----
        # graphs[0]: [0, 0, ..., 0]
        # ...
        # graphs[-1]: [0, 0, ..., 0]
        # ------ repeat 1 -----
        # graphs[0]: [1, 0, ..., 0]
        # ...
        # graphs[-1]: [1, 0, ..., 0]
        # ...
        # ------ repeat -1 -----
        # graphs[0]: [1, ..., 1, 0]
        # ...
        # graphs[-1]: [1, ..., 1, 0]
        num_cum_xs = num_xs.cumsum(0)
        starts = num_cum_xs - num_xs
        if lengths is None:
            num_max_x = num_xs.max().item()
            lengths = torch.arange(0, num_max_x, 2, device=num_xs.device)

        pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1]
        # starts, lengths, ends: (num_repeat, num_graph)
        starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1)
        valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 2
        lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 2).clamp(0)
        ends = starts + lengths

        starts = starts.flatten()
        ends = ends.flatten()
        valid = valid.flatten()

        return starts, ends, valid

    @torch.no_grad()
    def _get_reaction_feature(self, reactant, synthon):

        def get_edge_map(graph, num_nodes):
            node_in, node_out = graph.edge_list.t()[:2]
            node_in2id = graph.atom_map[node_in]
            node_out2id = graph.atom_map[node_out]
            edge_map = node_in2id * num_nodes[graph.edge2graph] + node_out2id
            # edges containing any unmapped node are considered to be unmapped
            edge_map[(node_in2id == 0) | (node_out2id == 0)] = 0
            return edge_map

        def get_mapping(reactant_x, synthon_x, reactant_x2graph, synthon_x2graph):
            num_xs = scatter_max(reactant_x, reactant_x2graph)[0]
            num_xs = num_xs.clamp(0) + 1
            num_cum_xs = num_xs.cumsum(0)
            offset = num_cum_xs - num_xs
            reactant2id = reactant_x + offset[reactant_x2graph]
            synthon2id = synthon_x + offset[synthon_x2graph]
            assert synthon2id.min() > 0
            id2synthon = -torch.ones(num_cum_xs[-1], dtype=torch.long, device=self.device)
            id2synthon[synthon2id] = torch.arange(len(synthon2id), device=self.device)
            reactant2synthon = id2synthon[reactant2id]

            return reactant2synthon

        # reactant & synthon may have different number of nodes
        # reactant.num_nodes >= synthon.num_nodes
        assert (reactant.num_nodes >= synthon.num_nodes).all()
        reactant_edge_map = get_edge_map(reactant, reactant.num_nodes)
        synthon_edge_map = get_edge_map(synthon, reactant.num_nodes)

        node_r2s = get_mapping(reactant.atom_map, synthon.atom_map, reactant.node2graph, synthon.node2graph)
        edge_r2s = get_mapping(reactant_edge_map, synthon_edge_map, reactant.edge2graph, synthon.edge2graph)

        is_new_node = node_r2s == -1
        is_new_edge = edge_r2s == -1
        is_modified_edge = (edge_r2s != -1) & (reactant.bond_type != synthon.bond_type[edge_r2s])
        is_reaction_center = (reactant.atom_map > 0) & \
                             (reactant.atom_map.unsqueeze(-1) ==
                              reactant.reaction_center[reactant.node2graph]).any(dim=-1)

        return node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center

    @torch.no_grad()
    def all_edge(self, reactant, synthon):
        graph = reactant.clone()
        node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center = \
            self._get_reaction_feature(reactant, synthon)
        with graph.node():
            graph.node_r2s = node_r2s
            graph.is_new_node = is_new_node
            graph.is_reaction_center = is_reaction_center
        with graph.edge():
            graph.edge_r2s = edge_r2s
            graph.is_new_edge = is_new_edge
            graph.is_modified_edge = is_modified_edge

        starts, ends, valid = self._all_prefix_slice(reactant.num_edges)
        num_repeat = len(starts) // len(reactant)
        graph = graph.repeat(num_repeat)

        # autoregressive condition range for each sample
        condition_mask = functional.multi_slice_mask(starts, ends, graph.num_edge)
        # special case: end == graph.num_edge. In this case, valid is always false
        assert ends.max() <= graph.num_edge
        ends = ends.clamp(0, graph.num_edge - 1)
        node_in, node_out, bond_target = graph.edge_list[ends].t()
        # modified edges which don't appear in conditions should keep their old bond types
        # i.e. bond types in synthons
        unmodified = ~condition_mask & graph.is_modified_edge
        unmodified = unmodified.nonzero().squeeze(-1)
        assert not (graph.bond_type[unmodified] == synthon.bond_type[graph.edge_r2s[unmodified]]).any()
        graph.edge_list[unmodified, 2] = synthon.edge_list[graph.edge_r2s[unmodified], 2]

        reverse_target = graph.edge_list[ends][:, [1, 0, 2]]
        is_reverse_target = (graph.edge_list == reverse_target[graph.edge2graph]).all(dim=-1)
        # keep edges that exist in the synthon
        # remove the reverse of new target edges
        edge_mask = (condition_mask & ~is_reverse_target) | ~graph.is_new_edge

        atom_in = graph.atom_type[node_in]
        atom_out = graph.atom_type[node_out]
        # keep one supervision for undirected edges
        # remove samples that try to predict existing edges
        valid &= (node_in < node_out) & (graph.is_new_edge[ends] | graph.is_modified_edge[ends])
        graph = graph.edge_mask(edge_mask)

        # sanitize the molecules
        # this will change atom index, so we manually remap the target nodes
        compact_mapping = -torch.ones(graph.num_node, dtype=torch.long, device=self.device)
        node_mask = graph.degree_in + graph.degree_out > 0
        # special case: for graphs without any edge, the first node should be kept
        index = torch.arange(graph.num_node, device=self.device)
        single_node_mask = (graph.num_edges == 0)[graph.node2graph] & \
                           (index == (graph.num_cum_nodes - graph.num_nodes)[graph.node2graph])
        node_index = (node_mask | single_node_mask).nonzero().squeeze(-1)
        compact_mapping[node_index] = torch.arange(len(node_index), device=self.device)
        node_in = compact_mapping[node_in]
        node_out = compact_mapping[node_out]
        graph = graph.subgraph(node_index)

        node_in_target = node_in - graph.num_cum_nodes + graph.num_nodes
        assert (node_in_target[valid] < graph.num_nodes[valid]).all() and (node_in_target[valid] >= 0).all()
        # node2 might be a new node
        node_out_target = torch.where(node_out == -1, self.atom2id[atom_out] + graph.num_nodes,
                                      node_out - graph.num_cum_nodes + graph.num_nodes)
        stop_target = torch.zeros(len(node_in_target), device=self.device)

        graph = graph[valid]
        node_in_target = node_in_target[valid]
        node_out_target = node_out_target[valid]
        bond_target = bond_target[valid]
        stop_target = stop_target[valid]

        assert (graph.num_edges % 2 == 0).all()
        # node / edge features may change because we mask some nodes / edges
        graph, feature_valid = self._update_molecule_feature(graph)

        return graph[feature_valid], node_in_target[feature_valid], node_out_target[feature_valid], \
               bond_target[feature_valid], stop_target[feature_valid]

    @torch.no_grad()
    def all_stop(self, reactant, synthon):
        graph = reactant.clone()
        node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center = \
            self._get_reaction_feature(reactant, synthon)
        with graph.node():
            graph.node_r2s = node_r2s
            graph.is_new_node = is_new_node
            graph.is_reaction_center = is_reaction_center
        with graph.edge():
            graph.edge_r2s = edge_r2s
            graph.is_new_edge = is_new_edge
            graph.is_modified_edge = is_modified_edge

        node_in_target = torch.zeros(len(graph), dtype=torch.long, device=self.device)
        node_out_target = torch.zeros_like(node_in_target)
        bond_target = torch.zeros_like(node_in_target)
        stop_target = torch.ones(len(graph), device=self.device)

        # keep consistent with other training data
        graph, feature_valid = self._update_molecule_feature(graph)

        return graph[feature_valid], node_in_target[feature_valid], node_out_target[feature_valid], \
               bond_target[feature_valid], stop_target[feature_valid]

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred, target = self.predict_and_target(batch, all_loss, metric)
        node_in_pred, node_out_pred, bond_pred, stop_pred = pred
        node_in_target, node_out_target, bond_target, stop_target, size = target

        loss = functional.variadic_cross_entropy(node_in_pred, node_in_target, size, reduction="none")
        loss = functional.masked_mean(loss, stop_target == 0)
        metric["node in ce loss"] = loss
        all_loss += loss

        loss = functional.variadic_cross_entropy(node_out_pred, node_out_target, size, reduction="none")
        loss = functional.masked_mean(loss, stop_target == 0)
        metric["node out ce loss"] = loss
        all_loss += loss

        loss = F.cross_entropy(bond_pred, bond_target, reduction="none")
        loss = functional.masked_mean(loss, stop_target == 0)
        metric["bond ce loss"] = loss
        all_loss += loss

        # Do we need to balance stop pred?
        loss = F.binary_cross_entropy_with_logits(stop_pred, stop_target)
        metric["stop bce loss"] = loss
        all_loss += loss

        metric["total loss"] = all_loss
        metric.update(self.evaluate(pred, target))

        return all_loss, metric

    def evaluate(self, pred, target):
        node_in_pred, node_out_pred, bond_pred, stop_pred = pred
        node_in_target, node_out_target, bond_target, stop_target, size = target

        metric = {}

        node_in_acc = metrics.variadic_accuracy(node_in_pred, node_in_target, size)
        accuracy = functional.masked_mean(node_in_acc, stop_target == 0)
        metric["node in accuracy"] = accuracy

        node_out_acc = metrics.variadic_accuracy(node_out_pred, node_out_target, size)
        accuracy = functional.masked_mean(node_out_acc, stop_target == 0)
        metric["node out accuracy"] = accuracy

        bond_acc = (bond_pred.argmax(-1) == bond_target).float()
        accuracy = functional.masked_mean(bond_acc, stop_target == 0)
        metric["bond accuracy"] = accuracy

        stop_acc = ((stop_pred > 0.5) == (stop_target > 0.5)).float()
        metric["stop accuracy"] = stop_acc.mean()

        total_acc = (node_in_acc > 0.5) & (node_out_acc > 0.5) & (bond_acc > 0.5) & (stop_acc > 0.5)
        total_acc = torch.where(stop_target == 0, total_acc, stop_acc > 0.5).float()
        metric["total accuracy"] = total_acc.mean()

        return metric

    def _cat(self, graphs):
        for i, graph in enumerate(graphs):
            if not isinstance(graph, data.PackedGraph):
                graphs[i] = graph.pack([graph])

        edge_list = torch.cat([graph.edge_list for graph in graphs])
        pack_num_nodes = torch.stack([graph.num_node for graph in graphs])
        pack_num_edges = torch.stack([graph.num_edge for graph in graphs])
        pack_num_cum_edges = pack_num_edges.cumsum(0)
        graph_index = pack_num_cum_edges < len(edge_list)
        pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index],
                                   dim_size=len(edge_list))
        pack_offsets = pack_offsets.cumsum(0)

        edge_list[:, :2] += pack_offsets.unsqueeze(-1)
        offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets

        edge_weight = torch.cat([graph.edge_weight for graph in graphs])
        num_nodes = torch.cat([graph.num_nodes for graph in graphs])
        num_edges = torch.cat([graph.num_edges for graph in graphs])
        num_relation = graphs[0].num_relation
        assert all(graph.num_relation == num_relation for graph in graphs)

        # only keep attributes that exist in all graphs
        keys = set(graphs[0].meta_dict.keys())
        for graph in graphs:
            keys = keys.intersection(graph.meta_dict.keys())

        meta_dict = {k: graphs[0].meta_dict[k] for k in keys}
        data_dict = {}
        for k in keys:
            data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs])

        return type(graphs[0])(edge_list, edge_weight=edge_weight, num_nodes=num_nodes,
                               num_edges=num_edges, num_relation=num_relation, offsets=offsets,
                               meta_dict=meta_dict, **data_dict)

    def target(self, batch):
        reactant, synthon = batch["graph"]

        graph1, node_in_target1, node_out_target1, bond_target1, stop_target1 = self.all_edge(reactant, synthon)
        graph2, node_in_target2, node_out_target2, bond_target2, stop_target2 = self.all_stop(reactant, synthon)

        node_in_target = torch.cat([node_in_target1, node_in_target2])
        node_out_target = torch.cat([node_out_target1, node_out_target2])
        bond_target = torch.cat([bond_target1, bond_target2])
        stop_target = torch.cat([stop_target1, stop_target2])
        size = torch.cat([graph1.num_nodes, graph2.num_nodes])
        # add new atom candidates into the size of each graph
        size_ext = size + self.num_atom_type

        return node_in_target, node_out_target, bond_target, stop_target, size_ext

    def _topk_action(self, graph, k):
        synthon_feature = torch.stack([graph.is_new_node, graph.is_reaction_center], dim=-1).float()
        node_feature = graph.node_feature.float() + self.input_linear(synthon_feature)
        output = self.model(graph, node_feature)

        node_feature = [output["node_feature"]]
        graph_feature = []
        for _feature in sorted(self.feature):
            if _feature == "reaction":
                reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32,
                                               device=self.device)
                reaction_feature.scatter_(1, graph.reaction.unsqueeze(-1), 1)
                graph_feature.append(reaction_feature)
            elif _feature == "graph":
                graph_feature.append(output["graph_feature"])
            elif _feature == "atom":
                node_feature.append(graph.node_feature.float())
            else:
                raise ValueError("Unknown feature `%s`" % _feature)

        graph_feature = torch.cat(graph_feature, dim=-1)
        # inherit graph features
        node_feature.append(graph_feature[graph.node2graph])
        node_feature = torch.cat(node_feature, dim=-1)

        new_node_feature = self.new_atom_feature.weight.repeat(len(graph), 1)
        new_graph_feature = graph_feature.unsqueeze(1).repeat(1, self.num_atom_type, 1).flatten(0, 1)
        new_node_feature = torch.cat([new_node_feature, new_graph_feature], dim=-1)
        node_feature, num_nodes_ext = self._extend(node_feature, graph.num_nodes, new_node_feature)

        node2graph_ext = torch.repeat_interleave(num_nodes_ext)
        num_cum_nodes_ext = num_nodes_ext.cumsum(0)
        starts = num_cum_nodes_ext - num_nodes_ext + graph.num_nodes
        ends = num_cum_nodes_ext
        is_new_node = functional.multi_slice_mask(starts, ends, num_cum_nodes_ext[-1])
        infinity = float("inf")

        node_in_pred = self.node_in_mlp(node_feature).squeeze(-1)
        stop_pred = self.stop_mlp(graph_feature).squeeze(-1)

        # mask out node-in prediction on new atoms
        node_in_pred[is_new_node] = -infinity
        node_in_logp = functional.variadic_log_softmax(node_in_pred, num_nodes_ext)  # (num_node,)
        stop_logp = F.logsigmoid(stop_pred)
        act_logp = F.logsigmoid(-stop_pred)
        node_in_topk = functional.variadic_topk(node_in_logp, num_nodes_ext, k)[1]
        assert (node_in_topk >= 0).all() and (node_in_topk < num_nodes_ext.unsqueeze(-1)).all()
        node_in = node_in_topk + (num_cum_nodes_ext - num_nodes_ext).unsqueeze(-1)  # (num_graph, k)

        # (num_node, node_in_k, feature_dim)
        node_out_feature = torch.cat([node_feature[node_in][node2graph_ext],
                                      node_feature.unsqueeze(1).expand(-1, k, -1)], dim=-1)
        node_out_pred = self.node_out_mlp(node_out_feature).squeeze(-1)
        # mask out node-out prediction on self-loops
        node_out_pred.scatter_(0, node_in, -infinity)
        # (num_node, node_in_k)
        node_out_logp = functional.variadic_log_softmax(node_out_pred, num_nodes_ext)
        # (num_graph, node_out_k, node_in_k)
        node_out_topk = functional.variadic_topk(node_out_logp, num_nodes_ext, k)[1]
        assert (node_out_topk >= 0).all() and (node_out_topk < num_nodes_ext.view(-1, 1, 1)).all()
        node_out = node_out_topk + (num_cum_nodes_ext - num_nodes_ext).view(-1, 1, 1)

        # (num_graph, node_out_k, node_in_k, feature_dim * 2)
        edge = torch.stack([node_in.unsqueeze(1).expand_as(node_out), node_out], dim=-1)
        bond_feature = node_feature[edge].flatten(-2)
        bond_pred = self.bond_mlp(bond_feature).squeeze(-1)
        bond_logp = F.log_softmax(bond_pred, dim=-1)  # (num_graph, node_out_k, node_in_k, num_relation)
        bond_type = torch.arange(bond_pred.shape[-1], device=self.device)
        bond_type = bond_type.view(1, 1, 1, -1).expand_as(bond_logp)

        # (num_graph, node_out_k, node_in_k, num_relation)
        node_in_logp = node_in_logp.gather(0, node_in.flatten(0, 1)).view(-1, 1, k, 1)
        node_out_logp = node_out_logp.gather(0, node_out.flatten(0, 1)).view(-1, k, k, 1)
        act_logp = act_logp.view(-1, 1, 1, 1)
        logp = node_in_logp + node_out_logp + bond_logp + act_logp

        # (num_graph, node_out_k, node_in_k, num_relation, 4)
        node_in_topk = node_in_topk.view(-1, 1, k, 1).expand_as(logp)
        node_out_topk = node_out_topk.view(-1, k, k, 1).expand_as(logp)
        action = torch.stack([node_in_topk, node_out_topk, bond_type, torch.zeros_like(bond_type)], dim=-1)

        # add stop action
        logp = torch.cat([logp.flatten(1), stop_logp.unsqueeze(-1)], dim=1)
        stop = torch.tensor([0, 0, 0, 1], device=self.device)
        stop = stop.view(1, 1, -1).expand(len(graph), -1, -1)
        action = torch.cat([action.flatten(1, -2), stop], dim=1)
        topk = logp.topk(k, dim=-1)[1]

        return action.gather(1, topk.unsqueeze(-1).expand(-1, -1, 4)), logp.gather(1, topk)

    def _apply_action(self, graph, action, logp):
        # only support non-variadic k-actions
        assert len(graph) == len(action)
        num_action = action.shape[1]

        graph = graph.repeat_interleave(num_action)

        action = action.flatten(0, 1)  # (num_graph * k, 4)
        logp = logp.flatten(0, 1)  # (num_graph * k)
        new_node_in, new_node_out, new_bond_type, stop = action.t()

        # add new nodes
        has_new_node = (new_node_out >= graph.num_nodes) & (stop == 0)
        new_atom_id = (new_node_out - graph.num_nodes)[has_new_node]
        new_atom_type = self.id2atom[new_atom_id]
        is_new_node = torch.ones(len(new_atom_type), dtype=torch.bool, device=self.device)
        is_reaction_center = torch.zeros(len(new_atom_type), dtype=torch.bool, device=self.device)
        atom_type, num_nodes = functional._extend(graph.atom_type, graph.num_nodes, new_atom_type, has_new_node)
        is_new_node = functional._extend(graph.is_new_node, graph.num_nodes, is_new_node, has_new_node)[0]
        is_reaction_center = functional._extend(graph.is_reaction_center, graph.num_nodes,
                                                is_reaction_center, has_new_node)[0]

        # cast to regular node ids
        new_node_out = torch.where(has_new_node, graph.num_nodes, new_node_out)

        # modify edges
        new_edge = torch.stack([new_node_in, new_node_out], dim=-1)
        edge_list = graph.edge_list.clone()
        bond_type = graph.bond_type.clone()
        edge_list[:, :2] -= graph._offsets.unsqueeze(-1)
        is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \
                           (stop[graph.edge2graph] == 0)
        has_modified_edge = scatter_max(is_modified_edge.long(), graph.edge2graph, dim_size=len(graph))[0] > 0
        bond_type[is_modified_edge] = new_bond_type[has_modified_edge]
        edge_list[is_modified_edge, 2] = new_bond_type[has_modified_edge]
        # modify reverse edges
        new_edge = new_edge.flip(-1)
        is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \
                           (stop[graph.edge2graph] == 0)
        bond_type[is_modified_edge] = new_bond_type[has_modified_edge]
        edge_list[is_modified_edge, 2] = new_bond_type[has_modified_edge]

        # add new edges
        has_new_edge = (~has_modified_edge) & (stop == 0)
        new_edge_list = torch.stack([new_node_in, new_node_out, new_bond_type], dim=-1)[has_new_edge]
        bond_type = functional._extend(bond_type, graph.num_edges, new_bond_type[has_new_edge], has_new_edge)[0]
        edge_list, num_edges = functional._extend(edge_list, graph.num_edges, new_edge_list, has_new_edge)
        # add reverse edges
        new_edge_list = torch.stack([new_node_out, new_node_in, new_bond_type], dim=-1)[has_new_edge]
        bond_type = functional._extend(bond_type, num_edges, new_bond_type[has_new_edge], has_new_edge)[0]
        edge_list, num_edges = functional._extend(edge_list, num_edges, new_edge_list, has_new_edge)

        logp = logp + graph.logp

        # inherit attributes
        data_dict = graph.data_dict
        meta_dict = graph.meta_dict
        for key in ["atom_type", "bond_type", "is_new_node", "is_reaction_center", "logp"]:
            data_dict.pop(key)
        # pad 0 for node / edge attributes
        for k, v in data_dict.items():
            if "node" in meta_dict[k]:
                shape = (len(new_atom_type), *v.shape[1:])
                new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
                data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0]
            if "edge" in meta_dict[k]:
                shape = (len(new_edge_list) * 2, *v.shape[1:])
                new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
                data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0]

        new_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type,
                                num_nodes=num_nodes, num_edges=num_edges, num_relation=graph.num_relation,
                                is_new_node=is_new_node, is_reaction_center=is_reaction_center, logp=logp,
                                meta_dict=meta_dict, **data_dict)
        with new_graph.graph():
            new_graph.is_stopped = stop == 1
        valid = logp > float("-inf")
        new_graph = new_graph[valid]

        new_graph, feature_valid = self._update_molecule_feature(new_graph)
        return new_graph[feature_valid]

    @torch.no_grad()
    def predict_reactant(self, batch, num_beam=10, max_prediction=20, max_step=20):
        if "synthon" in batch:
            synthon = batch["synthon"]
            synthon2product = torch.repeat_interleave(batch["num_synthon"])
            assert (synthon2product < len(batch["reaction"])).all()
            reaction = batch["reaction"][synthon2product]
        else:
            reactant, synthon = batch["graph"]
            reaction = batch["reaction"]

        # In any case, ensure that the synthon is a molecule rather than an ion
        # This is consistent across train/test routines in synthon completion
        synthon, feature_valid = self._update_molecule_feature(synthon)
        synthon = synthon[feature_valid]
        reaction = reaction[feature_valid]

        graph = synthon
        with graph.graph():
            # for convenience, because we need to manipulate graph a lot
            graph.reaction = reaction
            graph.synthon_id = torch.arange(len(graph), device=graph.device)
            if not hasattr(graph, "logp"):
                graph.logp = torch.zeros(len(graph), device=graph.device)
        with graph.node():
            graph.is_new_node = torch.zeros(graph.num_node, dtype=torch.bool, device=graph.device)
            graph.is_reaction_center = (graph.atom_map > 0) & \
                                       (graph.atom_map.unsqueeze(-1) ==
                                        graph.reaction_center[graph.node2graph]).any(dim=-1)

        result = []
        num_prediction = torch.zeros(len(synthon), dtype=torch.long, device=self.device)
        for i in range(max_step):
            logger.warning("action step: %d" % i)
            logger.warning("batched beam size: %d" % len(graph))
            # each candidate has #beam actions
            action, logp = self._topk_action(graph, num_beam)

            # each candidate is expanded to at most #beam (depending on validity) new candidates
            new_graph = self._apply_action(graph, action, logp)
            # assert (new_graph[is_stopped].logp > float("-inf")).all()
            offset = -2 * (new_graph.logp.max() - new_graph.logp.min())
            key = new_graph.synthon_id * offset + new_graph.logp
            order = key.argsort(descending=True)
            new_graph = new_graph[order]

            num_candidate = new_graph.synthon_id.bincount(minlength=len(synthon))
            topk = functional.variadic_topk(new_graph.logp, num_candidate, num_beam)[1]
            topk_index = topk + (num_candidate.cumsum(0) - num_candidate).unsqueeze(-1)
            topk_index = torch.unique(topk_index)
            new_graph = new_graph[topk_index]
            result.append(new_graph[new_graph.is_stopped])
            num_added = scatter_add(new_graph.is_stopped.long(), new_graph.synthon_id, dim_size=len(synthon))
            num_prediction += num_added

            # remove samples that already hit max prediction
            is_continue = (~new_graph.is_stopped) & (num_prediction[new_graph.synthon_id] < max_prediction)
            graph = new_graph[is_continue]
            if len(graph) == 0:
                break

        result = self._cat(result)
        # sort by synthon id
        order = result.synthon_id.argsort()
        result = result[order]

        # remove duplicate predictions
        is_duplicate = []
        synthon_id = -1
        for graph in result:
            if graph.synthon_id != synthon_id:
                synthon_id = graph.synthon_id
                smiles_set = set()
            smiles = graph.to_smiles(isomeric=False, atom_map=False, canonical=True)
            is_duplicate.append(smiles in smiles_set)
            smiles_set.add(smiles)
        is_duplicate = torch.tensor(is_duplicate, device=self.device)
        result = result[~is_duplicate]
        num_prediction = result.synthon_id.bincount(minlength=len(synthon))

        # remove extra predictions
        topk = functional.variadic_topk(result.logp, num_prediction, max_prediction)[1]
        topk_index = topk + (num_prediction.cumsum(0) - num_prediction).unsqueeze(-1)
        topk_index = topk_index.flatten(0)
        topk_index_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
                                        topk_index[:-1]])
        is_duplicate = topk_index == topk_index_shifted
        result = result[topk_index[~is_duplicate]]

        return result  # (< num_graph * max_prediction)

    def _extend(self, data, num_xs, input, input2graph=None):
        if input2graph is None:
            num_input_per_graph = len(input) // len(num_xs)
            input2graph = torch.arange(len(num_xs), device=data.device).unsqueeze(-1)
            input2graph = input2graph.repeat(1, num_input_per_graph).flatten()
        num_inputs = input2graph.bincount(minlength=len(num_xs))
        new_num_xs = num_xs + num_inputs
        new_num_cum_xs = new_num_xs.cumsum(0)
        new_num_x = new_num_cum_xs[-1].item()
        new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device)
        starts = new_num_cum_xs - new_num_xs
        ends = starts + num_xs
        index = functional.multi_slice_mask(starts, ends, new_num_x)
        new_data[index] = data
        new_data[~index] = input
        return new_data, new_num_xs

    def predict_and_target(self, batch, all_loss=None, metric=None):
        reactant, synthon = batch["graph"]
        reactant = reactant.clone()
        with reactant.graph():
            reactant.reaction = batch["reaction"]

        graph1, node_in_target1, node_out_target1, bond_target1, stop_target1 = self.all_edge(reactant, synthon)
        graph2, node_in_target2, node_out_target2, bond_target2, stop_target2 = self.all_stop(reactant, synthon)

        graph = self._cat([graph1, graph2])

        node_in_target = torch.cat([node_in_target1, node_in_target2])
        node_out_target = torch.cat([node_out_target1, node_out_target2])
        bond_target = torch.cat([bond_target1, bond_target2])
        stop_target = torch.cat([stop_target1, stop_target2])
        size = graph.num_nodes
        # add new atom candidates into the size of each graph
        size_ext = size + self.num_atom_type

        synthon_feature = torch.stack([graph.is_new_node, graph.is_reaction_center], dim=-1).float()
        node_feature = graph.node_feature.float() + self.input_linear(synthon_feature)
        output = self.model(graph, node_feature, all_loss, metric)

        node_feature = [output["node_feature"]]
        graph_feature = []
        for _feature in sorted(self.feature):
            if _feature == "reaction":
                reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32,
                                               device=self.device)
                reaction_feature.scatter_(1, graph.reaction.unsqueeze(-1), 1)
                graph_feature.append(reaction_feature)
            elif _feature == "graph":
                graph_feature.append(output["graph_feature"])
            elif _feature == "atom":
                node_feature.append(graph.node_feature)
            else:
                raise ValueError("Unknown feature `%s`" % _feature)

        graph_feature = torch.cat(graph_feature, dim=-1)
        # inherit graph features
        node_feature.append(graph_feature[graph.node2graph])
        node_feature = torch.cat(node_feature, dim=-1)

        new_node_feature = self.new_atom_feature.weight.repeat(len(graph), 1)
        new_graph_feature = graph_feature.unsqueeze(1).repeat(1, self.num_atom_type, 1).flatten(0, 1)
        new_node_feature = torch.cat([new_node_feature, new_graph_feature], dim=-1)
        node_feature, num_nodes_ext = self._extend(node_feature, graph.num_nodes, new_node_feature)
        assert (num_nodes_ext == size_ext).all()

        node2graph_ext = torch.repeat_interleave(num_nodes_ext)
        num_cum_nodes_ext = num_nodes_ext.cumsum(0)
        starts = num_cum_nodes_ext - num_nodes_ext + graph.num_nodes
        ends = num_cum_nodes_ext
        is_new_node = functional.multi_slice_mask(starts, ends, num_cum_nodes_ext[-1])

        node_in = node_in_target + num_cum_nodes_ext - num_nodes_ext
        node_out = node_out_target + num_cum_nodes_ext - num_nodes_ext
        edge = torch.stack([node_in, node_out], dim=-1)

        node_out_feature = torch.cat([node_feature[node_in][node2graph_ext], node_feature], dim=-1)
        bond_feature = node_feature[edge].flatten(-2)
        node_in_pred = self.node_in_mlp(node_feature).squeeze(-1)
        node_out_pred = self.node_out_mlp(node_out_feature).squeeze(-1)
        bond_pred = self.bond_mlp(bond_feature).squeeze(-1)
        stop_pred = self.stop_mlp(graph_feature).squeeze(-1)

        infinity = torch.tensor(float("inf"), device=self.device)
        # mask out node-in prediction on new atoms
        node_in_pred[is_new_node] = -infinity
        # mask out node-out prediction on self-loops
        node_out_pred[node_in] = -infinity

        return (node_in_pred, node_out_pred, bond_pred, stop_pred), \
               (node_in_target, node_out_target, bond_target, stop_target, size_ext)
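
# --------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library). It assumes the synthon variant of a
# USPTO50k-style dataset; the dataset path and hyperparameters are placeholders, and
# `core.Engine` is expected to call `task.preprocess`, which installs the RandomBFSOrder
# transform defined above.
def _example_synthon_completion():
    from torchdrug import datasets, models, core

    dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True,
                                node_feature="synthon_completion", kekulize=True)
    train_set, valid_set, test_set = dataset.split()

    model = models.RGCN(input_dim=dataset.node_feature_dim, hidden_dims=[256, 256, 256],
                        num_relation=dataset.num_bond_type, concat_hidden=True)
    task = SynthonCompletion(model, feature=("graph",))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer, batch_size=64)
    solver.train(num_epoch=1)
    return task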

@R.register("tasks.Retrosynthesis")
class Retrosynthesis(tasks.Task, core.Configurable):
    """
    Retrosynthesis task.

    This class wraps pretrained center identification and synthon completion modules into a pipeline.

    Parameters:
        center_identification (CenterIdentification): sub task of center identification
        synthon_completion (SynthonCompletion): sub task of synthon completion
        center_topk (int, optional): number of reaction centers to predict for each product
        num_synthon_beam (int, optional): size of beam search for each synthon
        max_prediction (int, optional): max number of final predictions for each product
        metric (str or list of str, optional): metric(s). Available metrics are ``top-K``.
    """

    _option_members = {"metric"}

    def __init__(self, center_identification, synthon_completion, center_topk=2, num_synthon_beam=10,
                 max_prediction=20, metric=("top-1", "top-3", "top-5", "top-10")):
        super(Retrosynthesis, self).__init__()
        self.center_identification = center_identification
        self.synthon_completion = synthon_completion
        self.center_topk = center_topk
        self.num_synthon_beam = num_synthon_beam
        self.max_prediction = max_prediction
        self.metric = metric

    def load_state_dict(self, state_dict, strict=True):
        if not strict:
            raise ValueError("Retrosynthesis only supports load_state_dict() with strict=True")
        keys = set(state_dict.keys())
        for model in [self.center_identification, self.synthon_completion]:
            if set(model.state_dict().keys()) == keys:
                return model.load_state_dict(state_dict, strict)
        raise RuntimeError("Neither of sub modules matches with state_dict")

    def predict(self, batch, all_loss=None, metric=None):
        synthon_batch = self.center_identification.predict_synthon(batch, self.center_topk)

        synthon = synthon_batch["synthon"]
        num_synthon = synthon_batch["num_synthon"]
        assert (num_synthon >= 1).all() and (num_synthon <= 2).all()
        synthon2split = torch.repeat_interleave(num_synthon)
        with synthon.graph():
            synthon.reaction_center = synthon_batch["reaction_center"][synthon2split]
            synthon.split_logp = synthon_batch["log_likelihood"][synthon2split]

        reactant = self.synthon_completion.predict_reactant(synthon_batch, self.num_synthon_beam,
                                                            self.max_prediction)

        logps = []
        reactant_ids = []
        product_ids = []

        # case 1: one synthon
        is_single = num_synthon[synthon2split[reactant.synthon_id]] == 1
        reactant_id = is_single.nonzero().squeeze(-1)
        logps.append(reactant.split_logp[reactant_id] + reactant.logp[reactant_id])
        product_ids.append(reactant.product_id[reactant_id])
        # pad -1
        reactant_ids.append(torch.stack([reactant_id, -torch.ones_like(reactant_id)], dim=-1))

        # case 2: two synthons
        # use proposal to avoid O(n^2) complexity
        reactant1 = torch.arange(len(reactant), device=self.device)
        reactant1 = reactant1.unsqueeze(-1).expand(-1, self.max_prediction * 2)
        reactant2 = reactant1 + torch.arange(self.max_prediction * 2, device=self.device)
        valid = reactant2 < len(reactant)
        reactant1 = reactant1[valid]
        reactant2 = reactant2[valid]
        synthon1 = reactant.synthon_id[reactant1]
        synthon2 = reactant.synthon_id[reactant2]
        valid = (synthon1 < synthon2) & (synthon2split[synthon1] == synthon2split[synthon2])
        reactant1 = reactant1[valid]
        reactant2 = reactant2[valid]
        logps.append(reactant.split_logp[reactant1] + reactant.logp[reactant1] + reactant.logp[reactant2])
        product_ids.append(reactant.product_id[reactant1])
        reactant_ids.append(torch.stack([reactant1, reactant2], dim=-1))

        # combine case 1 & 2
        logps = torch.cat(logps)
        reactant_ids = torch.cat(reactant_ids)
        product_ids = torch.cat(product_ids)

        order = product_ids.argsort()
        logps = logps[order]
        reactant_ids = reactant_ids[order]
        num_prediction = product_ids.bincount()
        logps, topk = functional.variadic_topk(logps, num_prediction, self.max_prediction)
        topk_index = topk + (num_prediction.cumsum(0) - num_prediction).unsqueeze(-1)
        topk_index_shifted = torch.cat([-torch.ones(len(topk_index), 1, dtype=torch.long, device=self.device),
                                        topk_index[:, :-1]], dim=-1)
        is_duplicate = topk_index == topk_index_shifted
        reactant_id = reactant_ids[topk_index]  # (num_graph, k, 2)

        # why we need to repeat the graph?
        # because reactant_id may be duplicated, which is not directly supported by graph indexing
        is_padding = reactant_id == -1
        num_synthon = (~is_padding).sum(dim=-1)
        num_synthon = num_synthon[~is_duplicate]
        logps = logps[~is_duplicate]
        offset = torch.arange(self.max_prediction, device=self.device) * len(reactant)
        reactant_id = reactant_id + offset.view(1, -1, 1)
        reactant_id = reactant_id[~(is_padding | is_duplicate.unsqueeze(-1))]
        reactant = reactant.repeat(self.max_prediction)
        reactant = reactant[reactant_id]
        assert num_synthon.sum() == len(reactant)
        synthon2graph = torch.repeat_interleave(num_synthon)
        first_synthon = num_synthon.cumsum(0) - num_synthon
        # inherit graph attributes from the first synthon
        data_dict = reactant.data_mask(graph_index=first_synthon, include="graph")[0]
        # merge synthon pairs from the same split into a single graph
        reactant = reactant.merge(synthon2graph)
        with reactant.graph():
            for k, v in data_dict.items():
                setattr(reactant, k, v)
            reactant.logps = logps

        num_prediction = reactant.product_id.bincount()

        return reactant, num_prediction  # (num_graph * k)

    def target(self, batch):
        reactant, product = batch["graph"]
        reactant = reactant.ion_to_molecule()
        return reactant

    def evaluate(self, pred, target):
        pred, num_prediction = pred
        infinity = torch.iinfo(torch.long).max - 1

        metric = {}
        ranking = []
        # any better solution for parallel graph isomorphism?
        num_cum_prediction = num_prediction.cumsum(0)
        for i in range(len(target)):
            target_smiles = target[i].to_smiles(isomeric=False, atom_map=False, canonical=True)
            offset = (num_cum_prediction[i] - num_prediction[i]).item()
            for j in range(num_prediction[i]):
                pred_smiles = pred[offset + j].to_smiles(isomeric=False, atom_map=False, canonical=True)
                if pred_smiles == target_smiles:
                    break
            else:
                j = infinity
            ranking.append(j + 1)

        ranking = torch.tensor(ranking, device=self.device)
        for _metric in self.metric:
            if _metric.startswith("top-"):
                threshold = int(_metric[4:])
                score = (ranking <= threshold).float().mean()
                metric["top-%d accuracy" % threshold] = score
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

        return metric
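
# --------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library): wrap two already-trained sub-tasks
# into the full retrosynthesis pipeline and score top-k accuracy on a small batch.
# `reaction_task` and `synthon_task` are assumed to be preprocessed and trained as in the
# sketches above, and `test_set` to be a split of the non-synthon (product) dataset.
def _example_retrosynthesis(reaction_task, synthon_task, test_set):
    task = Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                          num_synthon_beam=5, max_prediction=10)
    loader = data.DataLoader(test_set, batch_size=4)
    batch = next(iter(loader))
    pred = task.predict(batch)
    target = task.target(batch)
    return task.evaluate(pred, target)  # e.g. {"top-1 accuracy": ..., "top-3 accuracy": ..., ...}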