import torch
from torch import nn
from torch.nn import functional as F
from torchdrug import layers
[docs]class ConditionalFlow(nn.Module):
"""
Conditional flow transformation from `Masked Autoregressive Flow for Density Estimation`_.
.. _Masked Autoregressive Flow for Density Estimation:
https://arxiv.org/pdf/1705.07057.pdf
Parameters:
input_dim (int): input & output dimension
condition_dim (int): condition dimension
hidden_dims (list of int, optional): hidden dimensions
activation (str or function, optional): activation function
"""
def __init__(self, input_dim, condition_dim, hidden_dims=None, activation="relu"):
super(ConditionalFlow, self).__init__()
self.input_dim = input_dim
self.output_dim = input_dim
if hidden_dims is None:
hidden_dims = []
self.mlp = layers.MLP(condition_dim, list(hidden_dims) + [input_dim * 2], activation)
self.rescale = nn.Parameter(torch.zeros(1))
[docs] def forward(self, input, condition):
"""
Transform data into latent representations.
Parameters:
input (Tensor): input representations
condition (Tensor): conditional representations
Returns:
(Tensor, Tensor): latent representations, log-likelihood of the transformation
"""
scale, bias = self.mlp(condition).chunk(2, dim=-1)
scale = (F.tanh(scale) * self.rescale)
output = (input + bias) * scale.exp()
log_det = scale
return output, log_det
[docs] def reverse(self, latent, condition):
"""
Transform latent representations into data.
Parameters:
latent (Tensor): latent representations
condition (Tensor): conditional representations
Returns:
(Tensor, Tensor): input representations, log-likelihood of the transformation
"""
scale, bias = self.mlp(condition).chunk(2, dim=-1)
scale = (F.tanh(scale) * self.rescale)
output = latent / scale.exp() - bias
log_det = scale
return output, log_det