Example of a normalizing flow (Real NVP) learning the distribution on the left.
Normalizing flows operate by pushing a simple density through a series of transformations to produce a richer, potentially more multi-modal distribution. -- Papamakarios et al. 2021
These transformations have to be bijective and differentiable, with a differentiable inverse and a Jacobian determinant that is tractable to compute; in short, each transformation is a diffeomorphism. (Note that in the NF literature the terms *bijector* and *diffeomorphism* are used interchangeably.)
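What these requirements buy us is the change-of-variables formula: if $y = f(x)$ with $x \sim p_X$, then

$$\log p_Y(y) = \log p_X\big(f^{-1}(y)\big) + \log\left|\det J_{f^{-1}}(y)\right|.$$

The sketch below evaluates this formula in plain JAX for a hypothetical affine map $f(x) = Ax + b$ (the values of `A` and `b` are arbitrary assumptions), computing the Jacobian with `jax.jacfwd` rather than analytically as a real flow would. Because an affine map of a standard normal is again Gaussian, $\mathcal{N}(b, AA^\top)$, the result can be checked against the closed form.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Hypothetical invertible affine map y = A x + b, used only to check the formula.
A = jnp.array([[2.0, 0.5], [0.3, 1.5]])
b = jnp.array([1.0, -1.0])

def f_inv(y):
    # Inverse of the affine map: x = A^{-1} (y - b).
    return jnp.linalg.solve(A, y - b)

def base_log_prob(x):
    # Standard normal base density p_X.
    return multivariate_normal.logpdf(x, jnp.zeros(2), jnp.eye(2))

def flow_log_prob(y):
    # Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|.
    jac = jax.jacfwd(f_inv)(y)
    _, log_abs_det = jnp.linalg.slogdet(jac)
    return base_log_prob(f_inv(y)) + log_abs_det

y = jnp.array([0.7, -0.2])
# For an affine map the push-forward is known in closed form: N(b, A A^T).
exact = multivariate_normal.logpdf(y, b, A @ A.T)
print(flow_log_prob(y), exact)
```

In practice flows are built so that $\log|\det J|$ has a cheap analytic form; the autodiff Jacobian here is only for checking.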
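We start with a linear map, a rotation of the plane by an angle $\theta$, given by:

$$f(x) = R_\theta x, \qquad R_\theta = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}.$$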
In distrax we can construct the above map by subclassing the `Bijector` class.
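```python
import distrax
import jax.numpy as jnp


class OrthogonalProjection2D(distrax.Bijector):
    """Rotates 2D points by a fixed angle theta (an orthogonal, volume-preserving map)."""

    def __init__(self, theta):
        super().__init__(event_ndims_in=1, event_ndims_out=1)
        self.theta = theta
        self.sin_theta = jnp.sin(theta)
        self.cos_theta = jnp.cos(theta)
        # Transposed rotation matrix, so that x @ R rotates row vectors x by theta.
        self.R = jnp.array(
            [[self.cos_theta, -self.sin_theta], [self.sin_theta, self.cos_theta]]
        ).T

    def forward(self, x):
        return jnp.matmul(x, self.R)

    def inverse(self, y):
        # R is orthogonal, so its inverse is its transpose.
        return jnp.matmul(y, self.R.T)

    def forward_and_log_det(self, x):
        y = self.forward(x)
        # |det R| = 1, hence log|det J| = 0 for every sample in the batch.
        logdet = jnp.zeros(x.shape[:-1])
        return y, logdet

    def inverse_and_log_det(self, y):
        x = self.inverse(y)
        logdet = jnp.zeros(y.shape[:-1])
        return x, logdet
```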
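Because a rotation matrix is orthogonal, $|\det R_\theta| = 1$, so the log-determinant reported by `forward_and_log_det` and `inverse_and_log_det` should be $\log 1 = 0$ for every sample, not 1. The pure-JAX sketch below checks this numerically without distrax; the helpers `rotation_matrix` and `rotate` are hypothetical names, and the angle is an arbitrary example value.

```python
import jax
import jax.numpy as jnp

def rotation_matrix(theta):
    # Counter-clockwise rotation of column vectors by theta.
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s], [s, c]])

def rotate(x, theta):
    # Row-vector convention, matching OrthogonalProjection2D.forward: x @ R^T.
    return x @ rotation_matrix(theta).T

theta = 0.7  # arbitrary example angle
x = jnp.array([1.0, 2.0])

# Jacobian of the map at x; for a linear map it equals R itself.
jac = jax.jacfwd(rotate)(x, theta)
sign, log_abs_det = jnp.linalg.slogdet(jac)

# Orthogonality (R^T R = I) is what makes the inverse a simple transpose.
R = rotation_matrix(theta)
orthogonality_error = jnp.max(jnp.abs(R.T @ R - jnp.eye(2)))

# Rotating by -theta undoes the rotation by theta.
x_roundtrip = rotate(rotate(x, theta), -theta)
```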
Transforming an independent multivariate Gaussian distribution with `OrthogonalProjection2D` for a fixed angle $\theta$ yields a multivariate Gaussian distribution whose components are no longer independent (provided the two marginal variances differ), as can be seen below:
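Since the above bijector is linear, we already knew the result in closed form: for $X \sim \mathcal{N}(\mu, \Sigma)$ and the rotation matrix $R_\theta$, we have $Y = R_\theta X \sim \mathcal{N}(R_\theta \mu, \Sigma_Y)$ where $\Sigma_Y = R_\theta \Sigma R_\theta^\top$.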
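For a linear map the covariance transforms as $\operatorname{Cov}[R_\theta X] = R_\theta \Sigma R_\theta^\top$. The JAX sketch below compares that closed form with a Monte Carlo estimate, assuming a hypothetical angle $\theta = \pi/4$ and variances $(1, 4)$. It also shows why the variances must differ: rotating an isotropic Gaussian ($\Sigma = I$) leaves it isotropic, hence still independent.

```python
import jax
import jax.numpy as jnp

def rotation_matrix(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s], [s, c]])

theta = jnp.pi / 4                         # hypothetical angle
sigma = jnp.diag(jnp.array([1.0, 4.0]))    # independent components, unequal variances
R = rotation_matrix(theta)

# Closed form: Cov[R X] = R Sigma R^T; the off-diagonal entry signals dependence.
cov_exact = R @ sigma @ R.T

# Monte Carlo check: sample X ~ N(0, Sigma), rotate with the row-vector convention x @ R^T.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (200_000, 2)) * jnp.sqrt(jnp.diag(sigma))
y = x @ R.T
cov_mc = jnp.cov(y, rowvar=False)

# With equal variances the rotated covariance stays diagonal.
cov_isotropic = R @ jnp.eye(2) @ R.T
```

Here the closed form evaluates to $\begin{pmatrix} 2.5 & -1.5 \\ -1.5 & 2.5 \end{pmatrix}$, so the rotated components are negatively correlated.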
In the image below we chained a shift, a scale and the orthogonal projection (rotation). The left-hand side depicts the true distribution and the right-hand side the distribution inferred by maximum likelihood for the shift parameter.
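A minimal pure-JAX sketch of such a maximum-likelihood fit, standing in for the distrax chain: the scale, the rotation angle, the order of composition (scale, then rotate, then shift), the learning rate and the synthetic "true" shift are all assumptions for illustration. Only the shift is learned, by gradient descent on the negative log-likelihood obtained from the change-of-variables formula.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical fixed parameters of the chain; only the shift is learned.
THETA = jnp.pi / 6
SCALE = jnp.array([0.5, 2.0])

def rotation_matrix(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s], [s, c]])

R = rotation_matrix(THETA)

def forward(x, shift):
    # Assumed order: scale, then rotate (row vectors: x @ R^T), then shift.
    return (SCALE * x) @ R.T + shift

def inverse(y, shift):
    # Undo the chain in reverse order; R is orthogonal, so (.) @ R undoes (.) @ R^T.
    return ((y - shift) @ R) / SCALE

def neg_log_likelihood(shift, y):
    x = inverse(y, shift)
    log_base = jnp.sum(norm.logpdf(x), axis=-1)
    # Shift and rotation have log|det| = 0; the inverse scale contributes -sum(log|SCALE|).
    log_det_inv = -jnp.sum(jnp.log(jnp.abs(SCALE)))
    return -jnp.mean(log_base + log_det_inv)

# Synthetic data drawn from the same chain with a hypothetical "true" shift.
true_shift = jnp.array([1.0, -2.0])
y_data = forward(jax.random.normal(jax.random.PRNGKey(0), (5000, 2)), true_shift)

grad_fn = jax.jit(jax.grad(neg_log_likelihood))
shift = jnp.zeros(2)
for _ in range(500):
    shift = shift - 0.2 * grad_fn(shift, y_data)  # plain gradient descent
```

With a Gaussian base and the other parameters fixed, the maximum-likelihood shift is exactly the sample mean of the data, which makes this a convenient sanity check before training richer flows.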
Samples from toy densities (True, left) and the distributions inferred with MAF (middle) and Real NVP (right).
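Real NVP (Dinh et al.) builds its flows from affine coupling layers: one half of the input passes through unchanged and parameterizes an affine transformation of the other half, so the Jacobian is triangular and its log-determinant is just the sum of the log-scales. The sketch below is a single 2D coupling layer in plain JAX; the `conditioner` is a hypothetical toy function with made-up weights, standing in for the neural networks used in the paper, and is not distrax's implementation.

```python
import jax
import jax.numpy as jnp

def conditioner(x1, params):
    # Hypothetical toy conditioner standing in for Real NVP's neural networks:
    # maps the untouched half x1 to a log-scale s and a translation t.
    w_s, w_t = params
    return jnp.tanh(x1 * w_s), x1 * w_t

def coupling_forward(x, params):
    # Affine coupling: keep x1, transform x2 -> x2 * exp(s(x1)) + t(x1).
    x1, x2 = x[..., :1], x[..., 1:]
    s, t = conditioner(x1, params)
    y2 = x2 * jnp.exp(s) + t
    log_det = jnp.sum(s, axis=-1)  # triangular Jacobian: log|det| = sum of s
    return jnp.concatenate([x1, y2], axis=-1), log_det

def coupling_inverse(y, params):
    y1, y2 = y[..., :1], y[..., 1:]
    s, t = conditioner(y1, params)  # y1 == x1, so s and t are recomputed exactly
    x2 = (y2 - t) * jnp.exp(-s)
    return jnp.concatenate([y1, x2], axis=-1), -jnp.sum(s, axis=-1)

params = (jnp.array([0.8]), jnp.array([-1.3]))  # hypothetical weights
x = jnp.array([0.5, -1.2])
y, log_det = coupling_forward(x, params)
x_rec, inv_log_det = coupling_inverse(y, params)

# Cross-check the analytic log-determinant against the full autodiff Jacobian.
jac = jax.jacfwd(lambda v: coupling_forward(v, params)[0])(x)
_, log_det_autodiff = jnp.linalg.slogdet(jac)
```

Stacking several such layers while alternating which half is kept fixed lets every coordinate be transformed, which is how Real NVP reaches multi-modal densities like the ones shown above.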
- Normalizing Flows for Probabilistic Modeling and Inference. George Papamakarios, Eric Nalisnick, Danilo Jimenez Rezende, Shakir Mohamed, Balaji Lakshminarayanan.
- Density estimation using Real NVP. Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio.
- Deep transformation models: Tackling complex regression problems with neural network based transformation models. Beate Sick, Torsten Hothorn, Oliver Dürr.
- Robust normalizing flows using Bernstein-type polynomials. Sameera Ramasinghe, Kasun Fernando, Salman Khan, Nick Barnes.
- Building custom bijectors with TensorFlow Probability.