Source code for torchdrug.tasks.generation

import copy
import logging
import warnings
from collections import defaultdict

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch import nn
from torch_scatter import scatter_add, scatter_max
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from torchdrug import core, data, tasks, metrics, transforms
from torchdrug.core import Registry as R
from torchdrug.layers import functional
from torchdrug import layers


logger = logging.getLogger(__name__)


[docs]@R.register("tasks.AutoregressiveGeneration") class AutoregressiveGeneration(tasks.Task, core.Configurable): """ Autoregressive graph generation task. This class can be used to implement GraphAF proposed in `GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation`_. To do so, instantiate the node model and the edge model with two :class:`GraphAutoregressiveFlow <torchdrug.models.GraphAutoregressiveFlow>` models. .. _GraphAF\: A Flow-based Autoregressive Model for Molecular Graph Generation: https://arxiv.org/pdf/2001.09382.pdf Parameters: node_model (nn.Module): node likelihood model edge_model (nn.Module): edge likelihood model task (str or list of str, optional): property optimization task(s). Available tasks are ``plogp`` and ``qed``. num_node_sample (int, optional): number of node samples per graph. -1 for all samples. num_edge_sample (int, optional): number of edge samples per graph. -1 for all samples. max_edge_unroll (int, optional): max node id difference. If not provided, use the statistics from the training set. max_node (int, optional): max number of node. If not provided, use the statistics from the training set. criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are ``nll`` and ``ppo``. agent_update_interval (int, optional): update agent every n batch gamma (float, optional): reward discount rate reward_temperature (float, optional): temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward. baseline_momentum (float, optional): momentum for value function baseline """ eps = 1e-10 top_k = 10 _option_members = set(["task", "criterion"]) def __init__(self, node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, max_edge_unroll=None, max_node=None, criterion="nll", agent_update_interval=5, gamma=0.9, reward_temperature=1, baseline_momentum=0.9): super(AutoregressiveGeneration, self).__init__() self.node_model = node_model self.edge_model = edge_model self.agent_node_model = copy.deepcopy(node_model) self.agent_edge_model = copy.deepcopy(edge_model) self.task = task self.num_atom_type = self.node_model.input_dim self.num_bond_type = self.edge_model.input_dim self.num_node_sample = num_node_sample self.num_edge_sample = num_edge_sample self.max_edge_unroll = max_edge_unroll self.max_node = max_node self.criterion = criterion self.agent_update_interval = agent_update_interval self.gamma = gamma self.reward_temperature = reward_temperature self.baseline_momentum = baseline_momentum self.best_results = defaultdict(list) self.batch_id = 0
[docs] def preprocess(self, train_set, valid_set, test_set): """ Add atom id mapping and random BFS order to the training set. Compute ``max_edge_unroll`` and ``max_node`` on the training set if not provided. """ remap_atom_type = transforms.RemapAtomType(train_set.atom_types) train_set.transform = transforms.Compose([ train_set.transform, remap_atom_type, transforms.RandomBFSOrder(), ]) self.register_buffer("id2atom", remap_atom_type.id2atom) self.register_buffer("atom2id", remap_atom_type.atom2id) if self.max_edge_unroll is None or self.max_node is None: self.max_edge_unroll = 0 self.max_node = 0 train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling") for data in train_set: graph = data["graph"] if graph.edge_list.numel(): edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item() self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll) self.max_node = max(self.max_node, graph.num_node) logger.warning("max node = %d, max edge unroll = %d" % (self.max_node, self.max_edge_unroll)) self.register_buffer("node_baseline", torch.zeros(self.max_node + 1)) self.register_buffer("edge_baseline", torch.zeros(self.max_node + 1))
def forward(self, batch): """""" all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} for criterion, weight in self.criterion.items(): if criterion == "nll": _loss, _metric = self.density_estimation_forward(batch) all_loss += _loss * weight metric.update(_metric) elif criterion == "ppo": _loss, _metric = self.reinforce_forward(batch) all_loss += _loss * weight metric.update(_metric) else: raise ValueError("Unknown criterion `%s`" % criterion) return all_loss, metric def reinforce_forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} if self.batch_id % self.agent_update_interval == 0: self.agent_node_model.load_state_dict(self.node_model.state_dict()) self.agent_edge_model.load_state_dict(self.edge_model.state_dict()) self.batch_id += 1 # generation takes less time when early_stop=True graph = self.generate(len(batch["graph"]), off_policy=True, early_stop=True) if graph.num_nodes.max() == 1: raise ValueError("Generation results collapse to singleton molecules") reward = torch.zeros(len(graph), device=self.device) for task in self.task: if task == "plogp": plogp = metrics.penalized_logP(graph) metric["Penalized logP"] = plogp.mean() metric["Penalized logP (max)"] = plogp.max() self.update_best_result(graph, plogp, "Penalized logP") reward += (plogp / self.reward_temperature).exp() if plogp.max().item() > 5: print("Penalized logP max = %s" % plogp.max().item()) print(self.best_results["Penalized logP"]) elif task == "qed": qed = metrics.QED(graph) metric["QED"] = qed.mean() metric["QED (max)"] = qed.max() self.update_best_result(graph, qed, "QED") reward += (qed / self.reward_temperature).exp() #reward += qed * 3 if qed.max().item() > 0.93: print("QED max = %s" % qed.max().item()) print(self.best_results["QED"]) else: raise ValueError("Unknown task `%s`" % task) # these graph-level features will broadcast to all masked graphs with graph.graph(): graph.reward = reward graph.original_num_nodes = graph.num_nodes graph.atom_type = self.atom2id[graph.atom_type] is_training = self.training # easily got nan if BN is trained self.bn_eval() masked_graph, node_target = self.mask_node(graph, metric) # reward reshaping reward = masked_graph.reward masked_graph.atom_type = self.id2atom[masked_graph.atom_type] reward = reward * self.gamma ** (masked_graph.original_num_nodes - masked_graph.num_nodes).float() # per graph size reward baseline weight = torch.ones_like(masked_graph.num_nodes, dtype=torch.float) baseline = scatter_add(reward, masked_graph.num_nodes, dim_size=self.max_node + 1) / \ (scatter_add(weight, masked_graph.num_nodes, dim_size=self.max_node + 1) + self.eps) self.node_baseline = self.node_baseline * self.baseline_momentum + baseline * (1 - self.baseline_momentum) reward -= self.node_baseline[masked_graph.num_nodes] reward += masked_graph.is_valid masked_graph.atom_type = self.atom2id[masked_graph.atom_type] log_likelihood = self.node_model(masked_graph, node_target, None, all_loss, metric) agent_log_likelihood = self.agent_node_model(masked_graph, node_target, None, all_loss, metric) objective = functional.clipped_policy_gradient_objective(log_likelihood, agent_log_likelihood, reward) objective = objective.mean() metric["node PPO objective"] = objective all_loss += -objective masked_graph, edge_target, edge = self.mask_edge(graph, metric) # reward reshaping reward = masked_graph.reward masked_graph.atom_type = self.id2atom[masked_graph.atom_type] reward = reward * self.gamma ** (masked_graph.original_num_nodes - masked_graph.num_nodes).float() # per graph size reward baseline weight = torch.ones_like(masked_graph.num_nodes, dtype=torch.float) baseline = scatter_add(reward, masked_graph.num_nodes, dim_size=self.max_node + 1) / \ (scatter_add(weight, masked_graph.num_nodes, dim_size=self.max_node + 1) + self.eps) self.edge_baseline = self.edge_baseline * self.baseline_momentum + baseline * (1 - self.baseline_momentum) reward -= self.edge_baseline[masked_graph.num_nodes] reward += masked_graph.is_valid masked_graph.atom_type = self.atom2id[masked_graph.atom_type] log_likelihood = self.edge_model(masked_graph, edge_target, edge, all_loss, metric) agent_log_likelihood = self.agent_edge_model(masked_graph, edge_target, edge, all_loss, metric) objective = functional.clipped_policy_gradient_objective(log_likelihood, agent_log_likelihood, reward) objective = objective.mean() metric["edge PPO objective"] = objective all_loss += -objective self.bn_train(is_training) return all_loss, metric def density_estimation_forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} graph = batch["graph"] masked_graph, node_target = self.mask_node(graph, metric) log_likelihood = self.node_model(masked_graph, node_target, None, all_loss, metric) log_likelihood = log_likelihood.mean() metric["node log likelihood"] = log_likelihood all_loss += -log_likelihood masked_graph, edge_target, edge = self.mask_edge(graph, metric) log_likelihood = self.edge_model(masked_graph, edge_target, edge, all_loss, metric) log_likelihood = log_likelihood.mean() metric["edge log likelihood"] = log_likelihood all_loss += -log_likelihood return all_loss, metric def evaluate(self, batch): pred = None metric = {} graph, target = self.all_node(batch["graph"]) log_likelihood = self.node_model(graph, target) log_likelihood = log_likelihood.mean() metric["node log likelihood"] = log_likelihood graph, target = self.all_edge(batch["graph"]) log_likelihood = self.edge_model(graph, target) log_likelihood = log_likelihood.mean() metric["edge log likelihood"] = log_likelihood return pred, metric def bn_train(self, mode=True): for module in self.modules(): if isinstance(module, nn.BatchNorm1d): module.train(mode) def bn_eval(self): for module in self.modules(): if isinstance(module, nn.BatchNorm1d): module.eval() def update_best_result(self, graph, score, task): score = score.cpu() best_results = self.best_results[task] for s, i in zip(*score.sort(descending=True)): s = s.item() i = i.item() if len(best_results) == self.top_k and s < best_results[-1][0]: break best_results.append((s, graph[i].to_smiles())) best_results.sort(reverse=True) best_results = best_results[:self.top_k] self.best_results[task] = best_results @torch.no_grad() def generate(self, num_sample, max_resample=20, off_policy=False, early_stop=False, verbose=0): num_relation = self.num_bond_type - 1 is_training = self.training self.eval() if off_policy: node_model = self.agent_node_model edge_model = self.agent_edge_model else: node_model = self.node_model edge_model = self.edge_model edge_list = torch.zeros(0, 3, dtype=torch.long, device=self.device) num_nodes = torch.zeros(num_sample, dtype=torch.long, device=self.device) num_edges = torch.zeros_like(num_nodes) atom_type = torch.zeros(0, dtype=torch.long, device=self.device) graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, num_relation=num_relation) completed = torch.zeros(num_sample, dtype=torch.bool, device=self.device) for node_in in range(self.max_node): atom_pred = node_model.sample(graph) # why we add atom_pred even if it is completed? # because we need to batch edge model over (node_in, node_out), even on completed graphs atom_type, num_nodes = self._append(atom_type, num_nodes, atom_pred) graph = node_graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, num_relation=num_relation) start = max(0, node_in - self.max_edge_unroll) for node_out in range(start, node_in): is_valid = completed.clone() edge = torch.tensor([node_in, node_out], device=self.device).repeat(num_sample, 1) # default: non-edge bond_pred = (self.num_bond_type - 1) * torch.ones(num_sample, dtype=torch.long, device=self.device) for i in range(max_resample): # only resample invalid graphs mask = ~is_valid bond_pred[mask] = edge_model.sample(graph, edge)[mask] # check valency mask = (bond_pred < edge_model.input_dim - 1) & ~completed edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1) tmp_edge_list, tmp_num_edges = self._append(edge_list, num_edges, edge_pred, mask) edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1) tmp_edge_list, tmp_num_edges = self._append(tmp_edge_list, tmp_num_edges, edge_pred, mask) tmp_graph = data.PackedMolecule(tmp_edge_list, self.id2atom[atom_type], tmp_edge_list[:, -1], num_nodes, tmp_num_edges, num_relation=num_relation) is_valid = tmp_graph.is_valid | completed if is_valid.all(): break if not is_valid.all() and verbose: num_invalid = num_sample - is_valid.sum().item() num_working = num_sample - completed.sum().item() logger.warning("edge (%d, %d): %d / %d molecules are invalid even after %d resampling" % (node_in, node_out, num_invalid, num_working, max_resample)) mask = (bond_pred < edge_model.input_dim - 1) & ~completed edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1) edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask) edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1) edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask) graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, num_relation=num_relation) if node_in > 0: assert (graph.num_edges[completed] == node_graph.num_edges[completed]).all() completed |= graph.num_edges == node_graph.num_edges if early_stop: graph.atom_type = self.id2atom[graph.atom_type] completed |= ~graph.is_valid graph.atom_type = self.atom2id[graph.atom_type] if completed.all(): break self.train(is_training) # remove isolated atoms index = graph.degree_out > 0 # keep at least the first atom for each graph index[graph.num_cum_nodes - graph.num_nodes] = 1 graph = graph.subgraph(index) graph.atom_type = self.id2atom[graph.atom_type] graph = graph[graph.is_valid_rdkit] return graph def _append(self, data, num_xs, input, mask=None): if mask is None: mask = torch.ones_like(num_xs, dtype=torch.bool) new_num_xs = num_xs + mask new_num_cum_xs = new_num_xs.cumsum(0) new_num_x = new_num_cum_xs[-1].item() new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device) starts = new_num_cum_xs - new_num_xs ends = starts + num_xs index = functional.multi_slice_mask(starts, ends, new_num_x) new_data[index] = data new_data[~index] = input[mask] return new_data, new_num_xs @torch.no_grad() def mask_node(self, graph, metric=None): if self.num_node_sample == -1: masked_graph, node_target = self.all_node(graph) if metric is not None: metric["node mask / graph"] = torch.tensor([len(masked_graph) / len(graph)], device=graph.device) else: masked_graph, node_target = self.sample_node(graph, self.num_node_sample) return masked_graph, node_target @torch.no_grad() def mask_edge(self, graph, metric=None): if self.num_edge_sample == -1: masked_graph, edge_target, edge = self.all_edge(graph) if metric is not None: metric["edge mask / graph"] = torch.tensor([len(masked_graph) / len(graph)], device=graph.device) else: masked_graph, edge_target, edge = self.sample_edge(graph, self.num_edge_sample) return masked_graph, edge_target, edge @torch.no_grad() def sample_node(self, graph, num_sample): graph = graph.repeat(num_sample) num_nodes = graph.num_nodes num_keep_nodes = torch.rand(len(graph), device=graph.device) * num_nodes # [0, num_nodes) num_keep_nodes = num_keep_nodes.long() # [0, num_nodes - 1] starts = graph.num_cum_nodes - graph.num_nodes ends = starts + num_keep_nodes mask = functional.multi_slice_mask(starts, ends, graph.num_node) new_graph = graph.subgraph(mask) target = graph.subgraph(ends).atom_type return new_graph, target @torch.no_grad() def all_node(self, graph): starts, ends, valid = self._all_prefix_slice(graph.num_nodes) num_repeat = len(starts) // len(graph) graph = graph.repeat(num_repeat) mask = functional.multi_slice_mask(starts, ends, graph.num_node) new_graph = graph.subgraph(mask) target = graph.subgraph(ends).atom_type return new_graph[valid], target[valid] @torch.no_grad() def sample_edge(self, graph, num_sample): if (graph.num_nodes < 2).any(): graph = graph[graph.num_nodes >= 2] warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped") lengths = self._valid_edge_prefix_lengths(graph) graph = graph.repeat(num_sample) num_max_node = graph.num_nodes.max().item() num_node2num_dense_edge = torch.arange(num_max_node + 1, device=graph.device) ** 2 num_node2length_idx = (lengths.unsqueeze(-1) < num_node2num_dense_edge.unsqueeze(0)).sum(dim=0) # uniformly sample a mask from each graph's valid masks length_indexes = torch.rand(len(graph), device=graph.device) * num_node2length_idx[graph.num_nodes] length_indexes = length_indexes.long() num_keep_dense_edges = lengths[length_indexes] # undirected: all upper triangular edge ids are flipped to lower triangular ids # 1 -> 2, 4 -> 6, 5 -> 7 node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) node_in, node_out = node_index.t() node_large = node_index.max(dim=-1)[0] node_small = node_index.min(dim=-1)[0] edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small undirected_edge_id = node_large * (node_large + 1) + node_small edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() starts = graph.num_cum_nodes - graph.num_nodes ends = starts + circum_box_size node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) # compact nodes so that succeeding nodes won't affect graph pooling new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() # default: non-edge target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) target[positive_graph] = graph.edge_list[positive_edge, 2] node_in = circum_box_size - 1 node_out = num_keep_dense_edges - node_in * circum_box_size edge = torch.stack([node_in, node_out], dim=-1) return new_graph, target, edge @torch.no_grad() def all_edge(self, graph): if (graph.num_nodes < 2).any(): graph = graph[graph.num_nodes >= 2] warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped") lengths = self._valid_edge_prefix_lengths(graph) starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths) num_keep_dense_edges = ends - starts num_repeat = len(starts) // len(graph) graph = graph.repeat(num_repeat) # undirected: all upper triangular edge ids are flipped to lower triangular ids # 1 -> 2, 4 -> 6, 5 -> 7 node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) node_in, node_out = node_index.t() node_large = node_index.max(dim=-1)[0] node_small = node_index.min(dim=-1)[0] edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small undirected_edge_id = node_large * (node_large + 1) + node_small edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() starts = graph.num_cum_nodes - graph.num_nodes ends = starts + circum_box_size node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) # compact nodes so that succeeding nodes won't affect graph pooling new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() # default: non-edge target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) target[positive_graph] = graph.edge_list[positive_edge, 2] node_in = circum_box_size - 1 node_out = num_keep_dense_edges - node_in * circum_box_size edge = torch.stack([node_in, node_out], dim=-1) return new_graph[valid], target[valid], edge[valid] @torch.no_grad() def _all_prefix_slice(self, num_xs, lengths=None): # extract a bunch of slices that correspond to the following num_repeat * n masks # ------ repeat 0 ----- # graphs[0]: [0, 0, ..., 0] # ... # graphs[-1]: [0, 0, ..., 0] # ------ repeat 1 ----- # graphs[0]: [1, 0, ..., 0] # ... # graphs[-1]: [1, 0, ..., 0] # ... # ------ repeat -1 ----- # graphs[0]: [1, ..., 1, 0] # ... # graphs[-1]: [1, ..., 1, 0] num_cum_xs = num_xs.cumsum(0) starts = num_cum_xs - num_xs if lengths is None: num_max_x = num_xs.max().item() lengths = torch.arange(num_max_x, device=num_xs.device) pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1] # starts, lengths, ends: (num_repeat, num_graph) starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1) valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 1 lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 1) ends = starts + lengths starts = starts.flatten() ends = ends.flatten() valid = valid.flatten() return starts, ends, valid @torch.no_grad() def _valid_edge_prefix_lengths(self, graph): # valid prefix lengths are across a batch, according to the largest graph num_max_node = graph.num_nodes.max().item() # edge id in an adjacency (snake pattern) # in # o 0 1 4 # u 2 3 5 # t 6 7 8 lengths = torch.arange(num_max_node ** 2, device=graph.device) circum_box_size = (lengths + 1.0).sqrt().ceil().long() # only keep lengths that ends in the lower triangular part of adjacency matrix lengths = lengths[lengths >= circum_box_size * (circum_box_size - 1)] # lengths: [0, 2, 3, 6, 7, 8, ...] # num_node2length_idx: [0, 1, 4, 6, ...] # num_edge_unrolls # 0 # 1 0 # 2 1 0 num_edge_unrolls = (lengths + 1.0).sqrt().ceil().long() ** 2 - lengths - 1 # num_edge_unrolls: [0, 1, 0, 2, 1, 0, ...] # remove lengths that unroll too much. they always lead to empty targets lengths = lengths[(num_edge_unrolls <= self.max_edge_unroll) & (num_edge_unrolls > 0)] return lengths
[docs]@R.register("tasks.GCPNGeneration") class GCPNGeneration(tasks.Task, core.Configurable): """ The graph generative model from `Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation`_. .. _Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation: https://papers.nips.cc/paper/7877-graph-convolutional-policy-network-for-goal-directed-molecular-graph-generation.pdf Parameters: model (nn.Module): graph representation model atom_types (list or set): set of all possible atom types task (str or list of str, optional): property optimization task(s) max_edge_unroll (int, optional): max node id difference. If not provided, use the statistics from the training set. max_node (int, optional): max number of node. If not provided, use the statistics from the training set. criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are ``nll`` and ``ppo``. agent_update_interval (int, optional): update the agent every n batch gamma (float, optional): reward discount rate reward_temperature (float, optional): temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward. baseline_momentum (float, optional): momentum for value function baseline """ eps = 1e-10 top_k = 10 _option_members = set(["task", "criterion"]) def __init__(self, model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion="nll", hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9): super(GCPNGeneration, self).__init__() self.model = model self.task = task self.max_edge_unroll = max_edge_unroll self.max_node = max_node self.criterion = criterion self.hidden_dim_mlp = hidden_dim_mlp self.agent_update_interval = agent_update_interval self.gamma = gamma self.reward_temperature = reward_temperature self.baseline_momentum = baseline_momentum self.best_results = defaultdict(list) self.batch_id = 0 remap_atom_type = transforms.RemapAtomType(atom_types) self.register_buffer("id2atom", remap_atom_type.id2atom) self.register_buffer("atom2id", remap_atom_type.atom2id) self.new_atom_embeddings = nn.Parameter(torch.zeros(self.id2atom.size(0), self.model.output_dim)) nn.init.normal_(self.new_atom_embeddings, mean=0, std=0.1) self.inp_dim_stop = self.model.output_dim self.mlp_stop = layers.MultiLayerPerceptron(self.inp_dim_stop, [self.hidden_dim_mlp, 2], activation='tanh') self.inp_dim_node1 = self.model.output_dim + self.model.output_dim self.mlp_node1 = layers.MultiLayerPerceptron(self.inp_dim_node1, [self.hidden_dim_mlp, 1], activation='tanh') self.inp_dim_node2 = 2 * self.model.output_dim + self.model.output_dim self.mlp_node2 = layers.MultiLayerPerceptron(self.inp_dim_node2, [self.hidden_dim_mlp, 1], activation='tanh') self.inp_dim_edge = 2 * self.model.output_dim self.mlp_edge = layers.MultiLayerPerceptron(self.inp_dim_edge, [self.hidden_dim_mlp, self.model.num_relation], activation='tanh') self.agent_model = copy.deepcopy(self.model) self.agent_new_atom_embeddings = copy.deepcopy(self.new_atom_embeddings) self.agent_mlp_stop = copy.deepcopy(self.mlp_stop) self.agent_mlp_node1 = copy.deepcopy(self.mlp_node1) self.agent_mlp_node2 = copy.deepcopy(self.mlp_node2) self.agent_mlp_edge = copy.deepcopy(self.mlp_edge)
[docs] def preprocess(self, train_set, valid_set, test_set): """ Add atom id mapping and random BFS order to the training set. Compute ``max_edge_unroll`` and ``max_node`` on the training set if not provided. """ remap_atom_type = transforms.RemapAtomType(train_set.atom_types) train_set.transform = transforms.Compose([ train_set.transform, transforms.RandomBFSOrder(), ]) if self.max_edge_unroll is None or self.max_node is None: self.max_edge_unroll = 0 self.max_node = 0 train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling") for data in train_set: graph = data["graph"] if graph.edge_list.numel(): edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item() self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll) self.max_node = max(self.max_node, graph.num_node) logger.warning("max node = %d, max edge unroll = %d" % (self.max_node, self.max_edge_unroll)) self.register_buffer("moving_baseline", torch.zeros(self.max_node + 1))
def forward(self, batch): """""" all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} for criterion, weight in self.criterion.items(): if criterion == "nll": _loss, _metric = self.MLE_forward(batch) all_loss += _loss * weight metric.update(_metric) elif criterion == "ppo": _loss, _metric = self.reinforce_forward(batch) all_loss += _loss * weight metric.update(_metric) else: raise ValueError("Unknown criterion `%s`" % criterion) return all_loss, metric def predict(self, graph, label_dict, use_agent=False): # step1: get node/graph embeddings if not use_agent: output = self.model(graph, graph.node_feature.float()) else: output = self.agent_model(graph, graph.node_feature.float()) extended_node2graph = torch.arange(graph.num_nodes.size(0), device=self.device).unsqueeze(1).repeat([1, self.id2atom.size(0)]).view(-1) # (num_graph * 16) extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) graph_feature_per_node = output["graph_feature"][extended_node2graph] # step2: predict stop stop_feature = output["graph_feature"] #(num_graph, n_out) if not use_agent: stop_logits = self.mlp_stop(stop_feature) #(num_graph, 2) else: stop_logits = self.agent_mlp_stop(stop_feature) #(num_graph, 2) if label_dict == None: return stop_logits # step3: predict first node: node1 node1_feature = output["node_feature"] #(num_node, n_out) node1_feature = torch.cat((node1_feature, self.new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) # cat graph emb node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) if not use_agent: node1_logits = self.mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) else: node1_logits = self.agent_mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the extended part mask = torch.zeros(node1_logits.size(), device=self.device) mask[:graph.num_node] = 1 node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) # step4: predict second node: node2 node1_index_per_graph = (graph.num_cum_nodes - graph.num_nodes) + label_dict["label1"] #(num_graph) node1_index = node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) node2_feature_node1 = node1_feature[node1_index] #(num_node + 16 * num_graph, n_out) node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) #(num_node + 16 * num_graph, 2n_out) if not use_agent: node2_logits = self.mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) else: node2_logits = self.agent_mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the selected node1 mask = torch.zeros(node2_logits.size(), device=self.device) mask[node1_index_per_graph] = 1 node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) # step5: predict edge type is_new_node = label_dict["label2"] - graph.num_nodes # if an entry is non-negative, this is a new added node. (num_graph) graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) node2_index_per_graph = torch.where(is_new_node >= 0, graph.num_node + graph_offset * self.id2atom.size(0) + is_new_node, label_dict["label2"] + graph.num_cum_nodes - graph.num_nodes) # (num_graph) node2_index = node2_index_per_graph[extended_node2graph] edge_feature_node1 = node2_feature_node2[node1_index_per_graph] #(num_graph, n_out) edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # #(num_graph, n_out) edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) #(num_graph, 2n_out) if not use_agent: edge_logits = self.mlp_edge(edge_feature) # (num_graph, num_relation) else: edge_logits = self.agent_mlp_edge(edge_feature) # (num_graph, num_relation) index_dict = { "node1_index_per_graph": node1_index_per_graph, "node2_index_per_graph": node2_index_per_graph, "extended_node2graph": extended_node2graph } return stop_logits, node1_logits, node2_logits, edge_logits, index_dict def reinforce_forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} if self.batch_id % self.agent_update_interval == 0: self.agent_model.load_state_dict(self.model.state_dict()) self.agent_mlp_stop.load_state_dict(self.mlp_stop.state_dict()) self.agent_mlp_node1.load_state_dict(self.mlp_node1.state_dict()) self.agent_mlp_node2.load_state_dict(self.mlp_node2.state_dict()) self.agent_mlp_edge.load_state_dict(self.mlp_edge.state_dict()) self.agent_new_atom_embeddings.data = self.new_atom_embeddings.data.clone() self.batch_id += 1 # generation takes less time when early_stop=True graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) if graph.num_nodes.max() == 1: raise ValueError("Generation results collapse to singleton molecules") reward = torch.zeros(len(graph), device=self.device) for task in self.task: if task == "plogp": plogp = metrics.penalized_logP(graph) metric["Penalized logP"] = plogp.mean() metric["Penalized logP (max)"] = plogp.max() self.update_best_result(graph, plogp, "Penalized logP") # TODO: reward += (plogp / self.reward_temperature).exp() if plogp.max().item() > 5: print("Penalized logP max = %s" % plogp.max().item()) print(self.best_results["Penalized logP"]) elif task == "qed": qed = metrics.QED(graph) metric["QED"] = qed.mean() metric["QED (max)"] = qed.max() self.update_best_result(graph, qed, "QED") # TODO: #reward += ((qed - 0.9) * 20).exp() #reward += ((qed - 0.4) * 4 / self.reward_temperature).exp() #reward += qed reward += (qed / self.reward_temperature).exp() if qed.max().item() > 0.93: print("QED max = %s" % qed.max().item()) print(self.best_results["QED"]) else: raise ValueError("Unknown task `%s`" % task) # these graph-level features will broadcast to all masked graphs with graph.graph(): graph.reward = reward graph.original_num_nodes = graph.num_nodes #graph.atom_type = self.atom2id[graph.atom_type] is_training = self.training # easily got nan if BN is trained self.bn_eval() stop_graph, stop_label1, stop_label2, stop_label3, stop_label4 = self.all_stop(graph) edge_graph, edge_label1, edge_label2, edge_label3, edge_label4 = self.all_edge(graph) graph = self._cat([stop_graph, edge_graph]) label1_target = torch.cat([stop_label1, edge_label1]) label2_target = torch.cat([stop_label2, edge_label2]) label3_target = torch.cat([stop_label3, edge_label3]) label4_target = torch.cat([stop_label4, edge_label4]) label_dict = {"label1": label1_target, "label2": label2_target, "label3": label3_target, "label4": label4_target} # reward reshaping reward = graph.reward reward = reward * self.gamma ** (graph.original_num_nodes - graph.num_nodes).float() # per graph size reward baseline weight = torch.ones_like(graph.num_nodes, dtype=torch.float) baseline = scatter_add(reward, graph.num_nodes, dim_size=self.max_node + 1) / \ (scatter_add(weight, graph.num_nodes, dim_size=self.max_node + 1) + self.eps) # TODO: self.moving_baseline = self.moving_baseline * self.baseline_momentum + baseline * (1 - self.baseline_momentum) reward -= self.moving_baseline[graph.num_nodes] reward += graph.is_valid # calculate object stop_logits, node1_logits, node2_logits, edge_logits, index_dict = self.predict(graph, label_dict) with torch.no_grad(): old_stop_logits, old_node1_logits, old_node2_logits, old_edge_logits, old_index_dict = self.predict(graph, label_dict, use_agent=True) stop_prob = F.log_softmax(stop_logits, dim=-1) node1_prob = scatter_log_softmax(node1_logits, index_dict["extended_node2graph"]) node2_prob = scatter_log_softmax(node2_logits, index_dict["extended_node2graph"]) edge_prob = F.log_softmax(edge_logits, dim=-1) old_stop_prob = F.log_softmax(old_stop_logits, dim=-1) old_node1_prob = scatter_log_softmax(old_node1_logits, old_index_dict["extended_node2graph"]) old_node2_prob = scatter_log_softmax(old_node2_logits, old_index_dict["extended_node2graph"]) old_edge_prob = F.log_softmax(old_edge_logits, dim=-1) cur_logp = stop_prob[:, 0] + node1_prob[index_dict["node1_index_per_graph"]] \ + node2_prob[index_dict["node2_index_per_graph"]] + torch.gather(edge_prob, -1, label3_target.view(-1, 1)).view(-1) cur_logp[label4_target==1] = stop_prob[:, 1][label4_target==1] old_logp = old_stop_prob[:, 0] + old_node1_prob[old_index_dict["node1_index_per_graph"]] \ + old_node2_prob[index_dict["node2_index_per_graph"]] + torch.gather(old_edge_prob, -1, label3_target.view(-1, 1)).view(-1) old_logp[label4_target==1] = old_stop_prob[:, 1][label4_target==1] objective = functional.clipped_policy_gradient_objective(cur_logp, old_logp, reward) objective = objective.mean() metric["PPO objective"] = objective all_loss += (-objective) self.bn_train(is_training) return all_loss, metric def MLE_forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} graph = batch["graph"] stop_graph, stop_label1, stop_label2, stop_label3, stop_label4 = self.all_stop(graph) edge_graph, edge_label1, edge_label2, edge_label3, edge_label4 = self.all_edge(graph) graph = self._cat([stop_graph, edge_graph]) label1_target = torch.cat([stop_label1, edge_label1]) label2_target = torch.cat([stop_label2, edge_label2]) label3_target = torch.cat([stop_label3, edge_label3]) label4_target = torch.cat([stop_label4, edge_label4]) label_dict = {"label1": label1_target, "label2": label2_target, "label3": label3_target, "label4": label4_target} stop_logits, node1_logits, node2_logits, edge_logits, index_dict = self.predict(graph, label_dict) loss_stop = F.nll_loss(F.log_softmax(stop_logits, dim=-1), label4_target, reduction='none') loss_stop = 0.5 * (torch.mean(loss_stop[label4_target==0]) + torch.mean(loss_stop[label4_target==1])) #loss_stop = torch.mean(loss_stop) metric["stop bce loss"] = loss_stop all_loss += loss_stop loss_node1 = -(scatter_log_softmax(node1_logits, index_dict["extended_node2graph"])[index_dict["node1_index_per_graph"]]) loss_node1 = torch.mean(loss_node1[label4_target==0]) metric["node1 loss"] = loss_node1 all_loss += loss_node1 loss_node2 = -(scatter_log_softmax(node2_logits, index_dict["extended_node2graph"])[index_dict["node2_index_per_graph"]]) loss_node2 = torch.mean(loss_node2[label4_target==0]) metric["node2 loss"] = loss_node2 all_loss += loss_node2 loss_edge = F.nll_loss(F.log_softmax(edge_logits, dim=-1), label3_target, reduction='none') loss_edge = torch.mean(loss_edge[label4_target==0]) metric["edge loss"] = loss_edge all_loss += loss_edge metric["total loss"] = all_loss pred = stop_logits, node1_logits, node2_logits, edge_logits target = label1_target, label2_target, label3_target, label4_target, index_dict metric.update(self.evaluate(pred, target)) return all_loss, metric def evaluate(self, pred, target): stop_logits, node1_logits, node2_logits, edge_logits = pred label1_target, label2_target, label3_target, label4_target, index_dict = target metric = {} stop_acc = torch.argmax(stop_logits, -1) == label4_target metric["stop acc"] = stop_acc.float().mean() node1_pred = scatter_max(node1_logits, index_dict["extended_node2graph"])[1] node1_acc = node1_pred == index_dict["node1_index_per_graph"] metric["node1 acc"] = node1_acc[label4_target==0].float().mean() node2_pred = scatter_max(node2_logits, index_dict["extended_node2graph"])[1] node2_acc = node2_pred == index_dict["node2_index_per_graph"] metric["node2 acc"] = node2_acc[label4_target==0].float().mean() edge_acc = torch.argmax(edge_logits, -1) == label3_target metric["edge acc"] = edge_acc[label4_target==0].float().mean() return metric # generation step # 1. top-1 action # 2. apply action @torch.no_grad() def _construct_dist(self, prob_, graph): max_size = max(graph.num_nodes) + self.id2atom.size(0) probs = torch.zeros((len(graph.num_nodes), max_size),device=prob_.device).view(-1) #start = torch.arange(graph.num_nodes.size(0), device=self.device) * max_size start = (graph.num_cum_nodes - graph.num_nodes)[graph.node2graph] start = torch.arange(graph.num_node, device=self.device) - start index = torch.arange(graph.num_nodes.size(0), device=self.device) * max_size index = index[graph.node2graph] + start probs[index] = prob_[:graph.num_node] start_extend = torch.arange(self.id2atom.size(0), device=self.device).repeat(graph.num_nodes.size()) # (num_graph * 16) index_extend = torch.arange(graph.num_nodes.size(0), device=self.device) * max_size + graph.num_nodes index2graph = torch.arange(graph.num_nodes.size(0), device=self.device).unsqueeze(1).repeat(1, self.id2atom.size(0)).view(-1) index_extend = index_extend[index2graph] + start_extend probs[index_extend] = prob_[graph.num_node:] probs = probs.view(len(graph.num_nodes), max_size) return torch.distributions.Categorical(probs), probs # (n_graph, max_size) @torch.no_grad() def _sample_action(self, graph, off_policy): if off_policy: model = self.agent_model new_atom_embeddings = self.agent_new_atom_embeddings mlp_stop = self.agent_mlp_stop mlp_node1 = self.agent_mlp_node1 mlp_node2 = self.agent_mlp_node2 mlp_edge = self.agent_mlp_edge else: model = self.model new_atom_embeddings = self.new_atom_embeddings mlp_stop = self.mlp_stop mlp_node1 = self.mlp_node1 mlp_node2 = self.mlp_node2 mlp_edge = self.mlp_edge # step1: get feature output = model(graph, graph.node_feature.float()) extended_node2graph = torch.arange(graph.num_nodes.size(0), device=self.device).unsqueeze(1).repeat([1, self.id2atom.size(0)]).view(-1) # (num_graph * 16) extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) graph_feature_per_node = output["graph_feature"][extended_node2graph] # step2: predict stop stop_feature = output["graph_feature"] #(num_graph, n_out) stop_logits = mlp_stop(stop_feature) #(num_graph, 2) stop_prob = F.softmax(stop_logits, -1) #(num_graph, 2) #print('stop_prob', stop_prob) stop_prob_dist = torch.distributions.Categorical(stop_prob) stop_pred = stop_prob_dist.sample() #stop_pred = torch.argmax(stop_logits, -1) # (num_graph) # step3: predict first node: node1 node1_feature = output["node_feature"] #(num_node, n_out) node1_feature = torch.cat((node1_feature, new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) node1_logits = mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the extended part mask = torch.zeros(node1_logits.size(), device=self.device) mask[:graph.num_node] = 1 node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) node1_prob = scatter_softmax(node1_logits, extended_node2graph) # (num_node + 16 * num_graph) node1_prob_dist, tmp = self._construct_dist(node1_prob, graph) #(num_graph, max) #print('node1', tmp) node1_pred = node1_prob_dist.sample() #(num_graph) node1_index_per_graph = node1_pred + (graph.num_cum_nodes - graph.num_nodes) # step4: predict second node: node2 node1_index = node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) node2_feature_node1 = node1_feature[node1_index] #(num_node + 16 * num_graph, n_out) node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) #(num_node + 16 * num_graph, 2n_out) node2_logits = mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the selected node1 mask = torch.zeros(node2_logits.size(), device=self.device) mask[node1_index_per_graph] = 1 node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) node2_prob = scatter_softmax(node2_logits, extended_node2graph) # (num_node + 16 * num_graph) node2_prob_dist, tmp = self._construct_dist(node2_prob, graph) #(num_graph, max) #print('node2', tmp) node2_pred = node2_prob_dist.sample() #(num_graph) is_new_node = node2_pred - graph.num_nodes graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) node2_index_per_graph = torch.where(is_new_node >= 0, graph.num_node + graph_offset * self.id2atom.size(0) + is_new_node, node2_pred + graph.num_cum_nodes - graph.num_nodes) # step5: predict edge type edge_feature_node1 = node2_feature_node2[node1_index_per_graph] #(num_graph, n_out) edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # #(num_graph, n_out) edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) #(num_graph, 2n_out) edge_logits = mlp_edge(edge_feature) edge_prob = F.softmax(edge_logits, -1) #(num_graph, 3) #print('edge', edge_prob) edge_prob_dist = torch.distributions.Categorical(edge_prob) edge_pred = edge_prob_dist.sample() return stop_pred, node1_pred, node2_pred, edge_pred @torch.no_grad() def _top1_action(self, graph, off_policy): if off_policy: model = self.agent_model new_atom_embeddings = self.agent_new_atom_embeddings mlp_stop = self.agent_mlp_stop mlp_node1 = self.agent_mlp_node1 mlp_node2 = self.agent_mlp_node2 mlp_edge = self.agent_mlp_edge else: model = self.model new_atom_embeddings = self.new_atom_embeddings mlp_stop = self.mlp_stop mlp_node1 = self.mlp_node1 mlp_node2 = self.mlp_node2 mlp_edge = self.mlp_edge # step1: get feature output = model(graph, graph.node_feature.float()) extended_node2graph = torch.arange(graph.num_nodes.size(0), device=self.device).unsqueeze(1).repeat([1, self.id2atom.size(0)]).view(-1) # (num_graph * 16) extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) graph_feature_per_node = output["graph_feature"][extended_node2graph] # step2: predict stop stop_feature = output["graph_feature"] #(num_graph, n_out) stop_logits = mlp_stop(stop_feature) #(num_graph, 2) stop_pred = torch.argmax(stop_logits, -1) # (num_graph) #stop_prob = F.softmax(stop_logits, -1) #(num_graph, 2) #print('stop_prob', stop_prob) #stop_prob_dist = torch.distributions.Categorical(stop_prob) #stop_pred = stop_prob_dist.sample() #stop_pred = torch.argmax(stop_logits, -1) # (num_graph) # step3: predict first node: node1 node1_feature = output["node_feature"] #(num_node, n_out) node1_feature = torch.cat((node1_feature, new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) node1_logits = mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the extended part mask = torch.zeros(node1_logits.size(), device=self.device) mask[:graph.num_node] = 1 node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) node1_index_per_graph = scatter_max(node1_logits, extended_node2graph)[1] # (num_node + 16 * num_graph) node1_pred = node1_index_per_graph - (graph.num_cum_nodes - graph.num_nodes) #node1_prob = scatter_softmax(node1_logits, extended_node2graph) # (num_node + 16 * num_graph) #node1_prob_dist, tmp = self._construct_dist(node1_prob, graph) #(num_graph, max) #print('node1', tmp) #node1_pred = node1_prob_dist.sample() #(num_graph) #node1_index_per_graph = node1_pred + (graph.num_cum_nodes - graph.num_nodes) # step4: predict second node: node2 node1_index = node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) node2_feature_node1 = node1_feature[node1_index] #(num_node + 16 * num_graph, n_out) node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) #(num_node + 16 * num_graph, 2n_out) node2_logits = mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) #mask the selected node1 mask = torch.zeros(node2_logits.size(), device=self.device) mask[node1_index_per_graph] = 1 node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) node2_index_per_graph = scatter_max(node2_logits, extended_node2graph)[1] # (num_node + 16 * num_graph) is_new_node = node2_index_per_graph - graph.num_node # non negative if is new node #node2_prob = scatter_softmax(node2_logits, extended_node2graph) # (num_node + 16 * num_graph) #node2_prob_dist, tmp = self._construct_dist(node2_prob, graph) #(num_graph, max) #print('node2', tmp) #node2_pred = node2_prob_dist.sample() #(num_graph) #is_new_node = node2_pred - graph.num_nodes graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) node2_pred = torch.where(is_new_node>=0, graph.num_nodes + is_new_node - graph_offset * self.id2atom.size(0), node2_index_per_graph - (graph.num_cum_nodes - graph.num_nodes)) #node2_index_per_graph = torch.where(is_new_node >= 0, # graph.num_node + graph_offset * self.id2atom.size(0) + is_new_node, # node2_pred + graph.num_cum_nodes - graph.num_nodes) # step5: predict edge type edge_feature_node1 = node2_feature_node2[node1_index_per_graph] #(num_graph, n_out) edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # #(num_graph, n_out) edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) #(num_graph, 2n_out) edge_logits = mlp_edge(edge_feature) #edge_prob = F.softmax(edge_logits, -1) #(num_graph, 3) #print('edge', edge_prob) #edge_prob_dist = torch.distributions.Categorical(edge_prob) #edge_pred = edge_prob_dist.sample() edge_pred = torch.argmax(edge_logits, -1) return stop_pred, node1_pred, node2_pred, edge_pred @torch.no_grad() def _apply_action(self, graph, off_policy, max_resample=10, verbose=0, min_node=5): # action (num_graph, 4) # stopped graph is removed, initialize is_valid as False is_valid = torch.zeros(len(graph), dtype=torch.bool, device=self.device) stop_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) node1_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) node2_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) edge_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) for i in range(max_resample): # maximal resample time mask = ~is_valid if max_resample == 1: tmp_stop_action, tmp_node1_action, tmp_node2_action, tmp_edge_action = self._top1_action(graph, off_policy) else: tmp_stop_action, tmp_node1_action, tmp_node2_action, tmp_edge_action = self._sample_action(graph, off_policy) stop_action[mask] = tmp_stop_action[mask] node1_action[mask] = tmp_node1_action[mask] node2_action[mask] = tmp_node2_action[mask] edge_action[mask] = tmp_edge_action[mask] stop_action[graph.num_nodes <= 5] = torch.zeros(len(graph), dtype=torch.long, device=self.device)[graph.num_nodes <= 5] # tmp add new nodes has_new_node = (node2_action >= graph.num_nodes) & (stop_action == 0) new_atom_id = (node2_action - graph.num_nodes)[has_new_node] new_atom_type = self.id2atom[new_atom_id] atom_type, num_nodes = functional._extend(graph.atom_type, graph.num_nodes, new_atom_type, has_new_node) # tmp cast to regular node ids node2_action = torch.where(has_new_node, graph.num_nodes, node2_action) # tmp modify edges new_edge = torch.stack([node1_action, node2_action], dim=-1) edge_list = graph.edge_list.clone() bond_type = graph.bond_type.clone() edge_list[:, :2] -= graph._offsets.unsqueeze(-1) is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ (stop_action[graph.edge2graph] == 0) has_modified_edge = scatter_max(is_modified_edge.long(), graph.edge2graph, dim_size=len(graph))[0] > 0 bond_type[is_modified_edge] = edge_action[has_modified_edge] edge_list[is_modified_edge, 2] = edge_action[has_modified_edge] # tmp modify reverse edges new_edge = new_edge.flip(-1) is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ (stop_action[graph.edge2graph] == 0) bond_type[is_modified_edge] = edge_action[has_modified_edge] edge_list[is_modified_edge, 2] = edge_action[has_modified_edge] # tmp add new edges has_new_edge = (~has_modified_edge) & (stop_action == 0) new_edge_list = torch.stack([node1_action, node2_action, edge_action], dim=-1)[has_new_edge] bond_type = functional._extend(bond_type, graph.num_edges, edge_action[has_new_edge], has_new_edge)[0] edge_list, num_edges = functional._extend(edge_list, graph.num_edges, new_edge_list, has_new_edge) # tmp add reverse edges new_edge_list = torch.stack([node2_action, node1_action, edge_action], dim=-1)[has_new_edge] bond_type = functional._extend(bond_type, num_edges, edge_action[has_new_edge], has_new_edge)[0] edge_list, num_edges = functional._extend(edge_list, num_edges, new_edge_list, has_new_edge) tmp_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type, num_nodes=num_nodes, num_edges=num_edges, num_relation=graph.num_relation) is_valid = tmp_graph.is_valid | (stop_action == 1) if is_valid.all(): break if not is_valid.all() and verbose: num_invalid = len(graph) - is_valid.sum().item() num_working = len(graph) logger.warning("%d / %d molecules are invalid even after %d resampling" % (num_invalid, num_working, max_resample)) # apply the true action # inherit attributes data_dict = graph.data_dict meta_dict = graph.meta_dict for key in ["atom_type", "bond_type"]: data_dict.pop(key) # pad 0 for node / edge attributes for k, v in data_dict.items(): if meta_dict[k] == "node": shape = (len(new_atom_type), *v.shape[1:]) new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0] if meta_dict[k] == "edge": shape = (len(new_edge_list) * 2, *v.shape[1:]) new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0] new_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type, num_nodes=num_nodes, num_edges=num_edges, num_relation=graph.num_relation, meta_dict=meta_dict, **data_dict) with new_graph.graph(): new_graph.is_stopped = stop_action == 1 new_graph, feature_valid = self._update_molecule_feature(new_graph) return new_graph[feature_valid] def _update_molecule_feature(self, graphs): # This function is very slow mols = graphs.to_molecule(ignore_error=True) valid = [mol is not None for mol in mols] valid = torch.tensor(valid, device=graphs.device) new_graphs = type(graphs).from_molecule(mols, kekulize=True, node_feature="symbol") node_feature = torch.zeros(graphs.num_node, *new_graphs.node_feature.shape[1:], dtype=new_graphs.node_feature.dtype, device=graphs.device) edge_feature = torch.zeros(graphs.num_edge, *new_graphs.edge_feature.shape[1:], dtype=new_graphs.edge_feature.dtype, device=graphs.device) bond_type = torch.zeros_like(graphs.bond_type) node_mask = valid[graphs.node2graph] edge_mask = valid[graphs.edge2graph] node_feature[node_mask] = new_graphs.node_feature.to(device=graphs.device) edge_feature[edge_mask] = new_graphs.edge_feature.to(device=graphs.device) bond_type[edge_mask] = new_graphs.bond_type.to(device=graphs.device) with graphs.node(): graphs.node_feature = node_feature with graphs.edge(): graphs.edge_feature = edge_feature graphs.bond_type = bond_type return graphs, valid def bn_train(self, mode=True): for module in self.modules(): if isinstance(module, nn.BatchNorm1d): module.train(mode) def bn_eval(self): for module in self.modules(): if isinstance(module, nn.BatchNorm1d): module.eval() def update_best_result(self, graph, score, task): score = score.cpu() best_results = self.best_results[task] for s, i in zip(*score.sort(descending=True)): s = s.item() i = i.item() if len(best_results) == self.top_k and s < best_results[-1][0]: break best_results.append((s, graph[i].to_smiles())) best_results.sort(reverse=True) best_results = best_results[:self.top_k] self.best_results[task] = best_results @torch.no_grad() def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): is_training = self.training self.eval() graph = data.Molecule.from_smiles(initial_smiles, kekulize=True, node_feature="symbol").repeat(num_sample) # TODO: workaround if self.device.type == "cuda": graph = graph.cuda(self.device) result = [] for i in range(max_step): new_graph = self._apply_action(graph, off_policy, max_resample, verbose=1) if i == max_step - 1: # last step, collect all graph that is valid result.append(new_graph[(new_graph.num_nodes <= (self.max_node))]) else: result.append(new_graph[new_graph.is_stopped | (new_graph.num_nodes == (self.max_node))]) is_continue = (~new_graph.is_stopped) & (new_graph.num_nodes < (self.max_node)) graph = new_graph[is_continue] if len(graph) == 0: break self.train(is_training) result = self._cat(result) return result def _append(self, data, num_xs, input, mask=None): if mask is None: mask = torch.ones_like(num_xs, dtype=torch.bool) new_num_xs = num_xs + mask new_num_cum_xs = new_num_xs.cumsum(0) new_num_x = new_num_cum_xs[-1].item() new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device) starts = new_num_cum_xs - new_num_xs ends = starts + num_xs index = functional.multi_slice_mask(starts, ends, new_num_x) new_data[index] = data new_data[~index] = input[mask] return new_data, new_num_xs @torch.no_grad() def all_stop(self, graph): if (graph.num_nodes < 2).any(): graph = graph[graph.num_nodes >= 2] warnings.warn("Graphs with less than 2 nodes can't be used for stop prediction learning. Dropped") label1 = torch.zeros(len(graph), dtype=torch.long, device=self.device) label2 = torch.zeros_like(label1) label3 = torch.zeros_like(label1) return graph, label1, label2, label3, torch.ones(len(graph), dtype=torch.long, device=self.device) @torch.no_grad() def all_edge(self, graph): if (graph.num_nodes < 2).any(): graph = graph[graph.num_nodes >= 2] warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped") lengths = self._valid_edge_prefix_lengths(graph) starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths) num_keep_dense_edges = ends - starts num_repeat = len(starts) // len(graph) graph = graph.repeat(num_repeat) # undirected: all upper triangular edge ids are flipped to lower triangular ids # 1 -> 2, 4 -> 6, 5 -> 7 node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) node_in, node_out = node_index.t() node_large = node_index.max(dim=-1)[0] node_small = node_index.min(dim=-1)[0] edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small undirected_edge_id = node_large * (node_large + 1) + node_small edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() # check whether we need to add a new node for the current edge masked_undirected_edge_id = torch.where(edge_mask, undirected_edge_id, -torch.ones(undirected_edge_id.size(), dtype=torch.long, device=graph.device)) current_circum_box_size = scatter_max(masked_undirected_edge_id, graph.edge2graph, dim=0)[0] current_circum_box_size = (current_circum_box_size + 1.0).sqrt().ceil().long() is_new_node_edge = (circum_box_size > current_circum_box_size).long() starts = graph.num_cum_nodes - graph.num_nodes ends = starts + circum_box_size - is_new_node_edge node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) # compact nodes so that succeeding nodes won't affect graph pooling new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() # default: non-edge target = (self.model.num_relation) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) target[positive_graph] = graph.edge_list[:, 2][positive_edge] # node_in > node_out node_in = circum_box_size - 1 node_out = num_keep_dense_edges - node_in * circum_box_size # if we need to add a new node, what will be its atomid? new_node_atomid = self.atom2id[graph.atom_type[starts +node_in]] # keep only the positive graph, as we will add an edge at each step new_graph = new_graph[positive_graph] target = target[positive_graph] node_in = node_in[positive_graph] node_out = node_out[positive_graph] is_new_node_edge = is_new_node_edge[positive_graph] new_node_atomid = new_node_atomid[positive_graph] node_in_extend = new_graph.num_nodes + new_node_atomid node_in_final = torch.where(is_new_node_edge == 0, node_in, node_in_extend) return new_graph, node_out, node_in_final, target, torch.zeros_like(node_out) @torch.no_grad() def _all_prefix_slice(self, num_xs, lengths=None): # extract a bunch of slices that correspond to the following num_repeat * n masks # ------ repeat 0 ----- # graphs[0]: [0, 0, ..., 0] # ... # graphs[-1]: [0, 0, ..., 0] # ------ repeat 1 ----- # graphs[0]: [1, 0, ..., 0] # ... # graphs[-1]: [1, 0, ..., 0] # ... # ------ repeat -1 ----- # graphs[0]: [1, ..., 1, 0] # ... # graphs[-1]: [1, ..., 1, 0] num_cum_xs = num_xs.cumsum(0) starts = num_cum_xs - num_xs if lengths is None: num_max_x = num_xs.max().item() lengths = torch.arange(num_max_x, device=num_xs.device) pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1] # starts, lengths, ends: (num_repeat, num_graph) starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1) valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 1 lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 1).clamp(0) ends = starts + lengths starts = starts.flatten() ends = ends.flatten() valid = valid.flatten() return starts, ends, valid @torch.no_grad() def _valid_edge_prefix_lengths(self, graph): num_max_node = graph.num_nodes.max().item() # edge id in an adjacency (snake pattern) # in # o 0 1 4 # u 2 3 5 # t 6 7 8 lengths = torch.arange(num_max_node ** 2, device=graph.device) circum_box_size = (lengths + 1.0).sqrt().ceil().long() # only keep lengths that ends in the lower triangular part of adjacency matrix lengths = lengths[lengths >= circum_box_size * (circum_box_size - 1)] # lengths: [0, 2, 3, 6, 7, 8, ...] # num_node2length_idx: [0, 1, 4, 6, ...] # num_edge_unrolls # 0 # 1 0 # 2 1 0 num_edge_unrolls = (lengths + 1.0).sqrt().ceil().long() ** 2 - lengths - 1 # num_edge_unrolls: [0, 1, 0, 2, 1, 0, ...] # remove lengths that unroll too much. they always lead to empty targets. lengths = lengths[(num_edge_unrolls <= self.max_edge_unroll) & (num_edge_unrolls > 0)] return lengths def _cat(self, graphs): for i, graph in enumerate(graphs): if not isinstance(graph, data.PackedGraph): graphs[i] = graph.pack([graph]) edge_list = torch.cat([graph.edge_list for graph in graphs]) pack_num_nodes = torch.stack([graph.num_node for graph in graphs]) pack_num_edges = torch.stack([graph.num_edge for graph in graphs]) pack_num_cum_edges = pack_num_edges.cumsum(0) graph_index = pack_num_cum_edges < len(edge_list) pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index], dim_size=len(edge_list)) pack_offsets = pack_offsets.cumsum(0) edge_list[:, :2] += pack_offsets.unsqueeze(-1) offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets edge_weight = torch.cat([graph.edge_weight for graph in graphs]) num_nodes = torch.cat([graph.num_nodes for graph in graphs]) num_edges = torch.cat([graph.num_edges for graph in graphs]) num_relation = graphs[0].num_relation assert all(graph.num_relation == num_relation for graph in graphs) # only keep attributes that exist in all graphs keys = set(graphs[0].meta_dict.keys()) for graph in graphs: keys = keys.intersection(graph.meta_dict.keys()) meta_dict = {k: graphs[0].meta_dict[k] for k in keys} data_dict = {} for k in keys: data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs]) return type(graphs[0])(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets, meta_dict=meta_dict, **data_dict)