# Source code for torchdrug.models.flow

import torch
from torch import nn

from torchdrug import core, layers
from torchdrug.layers import functional
from torchdrug.core import Registry as R

@R.register("models.GraphAF")
class GraphAutoregressiveFlow(nn.Module, core.Configurable):
    r"""
    Graph autoregressive flow proposed in `GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation`_.

    .. _GraphAF\: a Flow-based Autoregressive Model for Molecular Graph Generation:
        https://arxiv.org/pdf/2001.09382.pdf

    Parameters:
        model (nn.Module): graph representation model
        prior (nn.Module): prior distribution
        use_edge (bool, optional): use edge or not
        num_layer (int, optional): number of conditional flow layers
        num_mlp_layer (int, optional): number of MLP layers in each conditional flow
        dequantization_noise (float, optional): scale of dequantization noise, must be less than 1

    Raises:
        ValueError: if ``dequantization_noise`` is not less than 1
    """

    def __init__(self, model, prior, use_edge=False, num_layer=6, num_mlp_layer=2, dequantization_noise=0.9):
        super(GraphAutoregressiveFlow, self).__init__()
        # Validate explicitly rather than with ``assert`` so the check survives ``python -O``.
        if not (dequantization_noise < 1):
            raise ValueError("`dequantization_noise` must be less than 1, but got %s" % (dequantization_noise,))

        self.model = model
        self.prior = prior
        self.use_edge = use_edge
        self.input_dim = self.output_dim = prior.dim
        self.dequantization_noise = dequantization_noise

        # With edges, the condition is the flattened concatenation of the two end point
        # features and the graph feature, hence three times the model output dimension.
        condition_dim = model.output_dim * (3 if use_edge else 1)
        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
        self.layers = nn.ModuleList()
        for _ in range(num_layer):
            self.layers.append(layers.ConditionalFlow(self.input_dim, condition_dim, hidden_dims))

    def _standarize_edge(self, graph, edge):
        """
        Return a copy of ``edge`` whose first two columns are shifted by
        ``graph.num_cum_nodes - graph.num_nodes``.

        The input tensor is never modified in place. ``None`` is passed through unchanged.
        Presumably this converts per-graph node indexes into indexes of the packed batch
        -- confirm against the ``Graph`` implementation.

        Raises:
            ValueError: if any value in the first two columns is not smaller than
                the corresponding entry of ``graph.num_nodes``
        """
        if edge is None:
            return None

        edge = edge.clone()
        if (edge[:, :2] >= graph.num_nodes.unsqueeze(-1)).any():
            raise ValueError("Edge index exceeds the number of nodes in the graph")
        edge[:, :2] += (graph.num_cum_nodes - graph.num_nodes).unsqueeze(-1)
        return edge

    def _compute_condition(self, graph, edge, all_loss, metric):
        """
        Build the conditioning tensor shared by :meth:`forward` and :meth:`sample`.

        Runs the representation model on the one-hot ``graph.atom_type`` and returns either
        the graph feature alone, or, when ``use_edge`` is set, the node features gathered at
        ``edge`` concatenated with the graph feature and flattened from dimension 1.

        Raises:
            ValueError: if ``use_edge`` is true but ``edge`` is ``None``,
                or if ``edge`` fails the check in :meth:`_standarize_edge`
        """
        if self.use_edge and edge is None:
            raise ValueError("`use_edge` is true, but no edge is provided")
        edge = self._standarize_edge(graph, edge)

        node_feature = functional.one_hot(graph.atom_type, self.model.input_dim)
        feature = self.model(graph, node_feature, all_loss, metric)
        node_feature = feature["node_feature"]
        graph_feature = feature["graph_feature"]

        if not self.use_edge:
            return graph_feature
        # node_feature[edge] gathers the features of every index in ``edge``; the graph
        # feature is appended as one more entry along dim 1 before flattening.
        return torch.cat([node_feature[edge], graph_feature.unsqueeze(1)], dim=1).flatten(1)

    def forward(self, graph, input, edge=None, all_loss=None, metric=None):
        """
        Compute the log-likelihood for the input given the graph(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): discrete data of shape :math:`(n,)`
            edge (Tensor, optional): edge list of shape :math:`(n, 2)`.
                If specified, additionally condition on the edge for each input.
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            Tensor: log-likelihood of shape :math:`(n,)`

        Raises:
            ValueError: if ``use_edge`` is true but ``edge`` is not provided,
                or if ``edge`` contains an out-of-range node index
        """
        condition = self._compute_condition(graph, edge, all_loss, metric)

        # Dequantize: add uniform noise scaled by ``dequantization_noise`` to the one-hot input.
        x = functional.one_hot(input, self.input_dim)
        x = x + self.dequantization_noise * torch.rand_like(x)

        log_dets = []
        for layer in self.layers:
            x, log_det = layer(x, condition)
            log_dets.append(log_det)

        log_prior = self.prior(x)
        log_det = torch.stack(log_dets).sum(dim=0)
        log_likelihood = log_prior + log_det
        log_likelihood = log_likelihood.sum(dim=-1)
        return log_likelihood  # (batch_size,)

    def sample(self, graph, edge=None, all_loss=None, metric=None):
        """
        Sample discrete data based on the given graph(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            edge (Tensor, optional): edge list of shape :math:`(n, 2)`.
                If specified, additionally condition on the edge for each input.
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            Tensor: sampled discrete data of shape :math:`(n,)`

        Raises:
            ValueError: if ``use_edge`` is true but ``edge`` is not provided,
                or if ``edge`` contains an out-of-range node index
        """
        condition = self._compute_condition(graph, edge, all_loss, metric)

        x = self.prior.sample(len(graph))
        # Apply the flow layers in reverse order, using each layer's ``reverse``.
        for layer in self.layers[::-1]:
            x, _ = layer.reverse(x, condition)

        output = x.argmax(dim=-1)
        return output  # (batch_size,)