
kaijennissen/Normalizing_Flows


Normalizing Flows

[Figure: Real NVP Example]

Example of a normalizing flow (Real NVP) learning the distribution on the left.

Intro

Normalizing flows operate by pushing a simple density through a series of transformations to produce a richer, potentially more multi-modal distribution. -- Papamakarios et al. 2021

These transformations have to be bijective and differentiable with a differentiable inverse (in short, diffeomorphisms), and their functional (Jacobian) determinant must be tractable to compute. (Note that in the NF literature the terms Bijector and diffeomorphism are used interchangeably.)
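A diffeomorphism $f$ lets us compute the density of $Y = f(X)$ exactly through the change-of-variables formula $p_Y(y) = p_X(f^{-1}(y))\,\lvert\det J_{f^{-1}}(y)\rvert$. The following sketch checks this numerically in one dimension for a hypothetical affine map $f(x) = a + b x$ applied to a standard normal $X$; it uses only the Python standard library, and the names `f_inv` and `pushforward_pdf` are illustrative, not part of this repository.

```python
import math


def normal_pdf(x, mean=0.0, std=1.0):
    """Density of N(mean, std**2) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


# Hypothetical affine bijector f(x) = a + b * x with b != 0, applied to X ~ N(0, 1).
a, b = 2.0, 3.0


def f_inv(y):
    return (y - a) / b


def f_inv_derivative(y):
    # For an affine map the Jacobian of the inverse is the constant 1 / b.
    return 1.0 / b


def pushforward_pdf(y):
    # Change of variables: p_Y(y) = p_X(f^{-1}(y)) * |d f^{-1} / dy|
    return normal_pdf(f_inv(y)) * abs(f_inv_derivative(y))


# Y = a + b X is known to be N(a, b**2), so the last two columns should agree.
for y in (-1.0, 2.0, 4.5):
    print(y, pushforward_pdf(y), normal_pdf(y, mean=a, std=abs(b)))
```

The affine case is chosen only because its pushforward density is known in closed form; for a general flow the same formula is applied with the learned inverse and its log-determinant.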

Building a custom Bijector with distrax

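We start with a linear map, the rotation by an angle $\theta$, given by:

$$f(x) = R_\theta\, x, \qquad R_\theta = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix},$$

with inverse:

$$f^{-1}(y) = R_\theta^{-1} y = R_\theta^\top y,$$

and functional (Jacobian) determinant $\det J_f = \det R_\theta = \cos^2\theta + \sin^2\theta = 1$, so that $\log\lvert\det J_f\rvert = 0$.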
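A quick numerical check of these three properties of the 2D rotation matrix (inverse equals transpose, determinant one, lengths preserved) can be sketched with NumPy; the helper `rotation_matrix` and the chosen angle are illustrative assumptions, and `jax.numpy` should behave the same way.

```python
import numpy as np


def rotation_matrix(theta):
    """2x2 counter-clockwise rotation by theta, acting on column vectors."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


theta = 0.7
R = rotation_matrix(theta)
x = np.array([1.5, -0.3])

y = R @ x          # forward:  f(x) = R_theta x
x_back = R.T @ y   # inverse:  f^{-1}(y) = R_theta^T y

print(np.allclose(x_back, x))                            # True
print(np.isclose(np.linalg.det(R), 1.0))                 # True, so log|det| = 0
print(np.isclose(np.linalg.norm(y), np.linalg.norm(x)))  # True: lengths are preserved
```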

In distrax we can construct the above map by subclassing the Bijector class.

import distrax
import jax.numpy as jnp

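class OrthogonalProjection2D(distrax.Bijector):
    """Rotation of 2D vectors by the angle theta (a volume-preserving bijection)."""

    def __init__(self, theta):
        super().__init__(event_ndims_in=1, event_ndims_out=1)
        self.theta = theta
        self.sin_theta = jnp.sin(theta)
        self.cos_theta = jnp.cos(theta)
        # Transposed rotation matrix, so that right-multiplying a row vector
        # (x @ R) rotates it counter-clockwise by theta.
        self.R = jnp.array(
            [[self.cos_theta, -self.sin_theta], [self.sin_theta, self.cos_theta]]
        ).T

    def forward(self, x):
        return jnp.matmul(x, self.R)

    def inverse(self, y):
        # R is orthogonal, so its inverse is its transpose.
        return jnp.matmul(y, self.R.T)

    def forward_and_log_det(self, x):
        y = self.forward(x)
        # |det R| = 1, hence log|det J| = 0 for every element of the batch.
        logdet = jnp.zeros(x.shape[:-1])
        return y, logdet

    def inverse_and_log_det(self, y):
        x = self.inverse(y)
        logdet = jnp.zeros(y.shape[:-1])
        return x, logdet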

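Transforming an independent multivariate Gaussian distribution with the OrthogonalProjection2D bijector for a given $\theta$ yields a multivariate Gaussian distribution whose components are no longer independent, as can be seen below:

[Figure: Rotation Bijector]

Since the above bijector is linear, we already knew this analytically: if $X \sim \mathcal{N}(\mu, \Sigma)$ with diagonal $\Sigma$, then $Y = R_\theta X \sim \mathcal{N}(R_\theta \mu, R_\theta \Sigma R_\theta^\top)$, and $R_\theta \Sigma R_\theta^\top$ has non-zero off-diagonal entries whenever the two variances differ and $\theta$ is not a multiple of $\pi/2$.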
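The closed-form result $Y = R_\theta X \sim \mathcal{N}(R_\theta\mu, R_\theta\Sigma R_\theta^\top)$ can be checked by Monte Carlo. The NumPy sketch below assumes $\theta = \pi/4$, $\mu = (1, 0)$ and $\Sigma = \operatorname{diag}(4, 0.25)$, which are illustrative values rather than the ones used for the figure.

```python
import numpy as np

rng = np.random.default_rng(0)

theta = np.pi / 4
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s], [s, c]])

# Independent Gaussian with unequal variances (an isotropic one would stay independent).
mu = np.array([1.0, 0.0])
Sigma = np.diag([4.0, 0.25])

x = rng.multivariate_normal(mu, Sigma, size=200_000)
y = x @ R.T  # applies y_i = R x_i to every row

mean_analytic = R @ mu
cov_analytic = R @ Sigma @ R.T

print(mean_analytic, y.mean(axis=0))
print(cov_analytic)              # off-diagonal entries ≈ 1.875: the components are correlated
print(np.cov(y, rowvar=False))   # empirical covariance, close to the analytic one
```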

In the image below we chained a shift, a scale and the OrthogonalProjection2D bijector. The left-hand side depicts the true distribution, the right-hand side the distribution inferred by maximum likelihood for the shift parameter $a$, the scale parameter $b$ and the rotation parameter $\theta$:

[Figure: Rotation Bijector 2]
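Maximum likelihood for such a chain only needs the inverse map and the log-determinants: $\log p_Y(y) = \log p_Z(f^{-1}(y)) + \log\lvert\det J_{f^{-1}}(y)\rvert$, where the rotation contributes nothing and the elementwise scale contributes $-\sum_j \log\lvert b_j\rvert$. The NumPy sketch below assumes the ordering $y = R_\theta(a + b \odot z)$ with $z \sim \mathcal{N}(0, I)$ (the order is not stated here) and, to stay short, recovers only $\theta$ by a grid search with $a$ and $b$ held at their true values. `log_likelihood` and the simulated parameter values are hypothetical; in practice one would compose distrax bijectors (for instance with `distrax.Chain`) and optimize all parameters jointly with gradients.

```python
import numpy as np


def rotation_matrix(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def log_likelihood(y, a, b, theta):
    """Mean log-density of the rows of y under y = R_theta (a + b * z), z ~ N(0, I_2).

    The ordering (scale, then shift, then rotate) is an assumption of this sketch.
    """
    z = (y @ rotation_matrix(theta) - a) / b  # inverse map; row-wise y @ R equals R^T y
    log_base = -0.5 * np.sum(z**2, axis=1) - np.log(2.0 * np.pi)  # log N(z; 0, I_2)
    log_det_inv = -np.sum(np.log(np.abs(b)))  # the rotation contributes log|det R| = 0
    return np.mean(log_base + log_det_inv)


rng = np.random.default_rng(42)
a_true, b_true, theta_true = np.array([1.0, -0.5]), np.array([2.0, 0.5]), 0.6
z = rng.standard_normal((20_000, 2))
y = (a_true + b_true * z) @ rotation_matrix(theta_true).T  # simulated data

# Maximum likelihood for theta alone on a grid (a and b held at their true values).
grid = np.linspace(-np.pi / 2, np.pi / 2, 361)
scores = [log_likelihood(y, a_true, b_true, t) for t in grid]
theta_hat = grid[int(np.argmax(scores))]
print(theta_hat)  # close to theta_true = 0.6
```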

MAF & Real NVP & Glow

Samples from toy densities (left) and the distributions inferred with MAF (middle) and Real NVP (right):

[Figure: grid of six toy densities (rows) with the columns True, MAF, Real NVP]
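The building block of Real NVP is the affine coupling layer: one half of the input passes through unchanged and parameterizes an elementwise affine map of the other half, so the Jacobian is triangular and its log-determinant is simply the sum of the log-scales. The NumPy sketch below uses fixed 1×1 linear maps as hypothetical stand-ins for the neural networks `s` and `t`; a full Real NVP stacks several such layers and alternates which half is kept (MAF instead conditions each dimension autoregressively on the preceding ones).

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical conditioner: fixed 1x1 linear maps stand in for the neural networks
# that output the log-scale s(x1) and the shift t(x1).
W_s = rng.normal(size=(1, 1))
W_t = rng.normal(size=(1, 1))


def s(x1):
    return np.tanh(x1 @ W_s)  # bounded log-scale keeps exp(s) numerically tame


def t(x1):
    return x1 @ W_t


def coupling_forward(x):
    """Affine coupling: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1)."""
    x1, x2 = x[:, :1], x[:, 1:]
    log_scale = s(x1)
    y2 = x2 * np.exp(log_scale) + t(x1)
    # The Jacobian is triangular, so log|det J| is the sum of the log-scales.
    log_det = np.sum(log_scale, axis=1)
    return np.concatenate([x1, y2], axis=1), log_det


def coupling_inverse(y):
    """Exact inverse; s and t themselves never need to be inverted."""
    y1, y2 = y[:, :1], y[:, 1:]
    x2 = (y2 - t(y1)) * np.exp(-s(y1))
    return np.concatenate([y1, x2], axis=1)


def numerical_jacobian(fn, x0, eps=1e-6):
    """Central finite differences of fn at a single point x0 of shape (1, d)."""
    cols = []
    for j in range(x0.shape[1]):
        e = np.zeros_like(x0)
        e[0, j] = eps
        cols.append((fn(x0 + e) - fn(x0 - e))[0] / (2 * eps))
    return np.stack(cols, axis=1)


x = rng.normal(size=(4, 2))
y, log_det = coupling_forward(x)
J = numerical_jacobian(lambda v: coupling_forward(v)[0], x[:1])
print(np.allclose(coupling_inverse(y), x))                    # True
print(np.isclose(np.log(abs(np.linalg.det(J))), log_det[0]))  # True
```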

Bernstein Flows

Univariate

[Figure: Bernstein Flow]
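A common building block of Bernstein flows is a monotonically increasing Bernstein polynomial $h(z) = \sum_{k=0}^{M} \beta_k \binom{M}{k} z^k (1-z)^{M-k}$ on $z \in [0, 1]$: strictly increasing coefficients $\beta_k$ make $h$ strictly increasing (hence invertible), and $\log h'(z)$ is the log-determinant term needed for maximum likelihood. The NumPy sketch below assumes the input has already been scaled to $[0, 1]$ and uses a cumulative-softplus parameterization of the coefficients; this is one common choice and not necessarily the one used in this repository. All function names are illustrative.

```python
import numpy as np
from math import comb


def bernstein_basis(z, degree):
    """Bernstein basis B_{k,degree}(z) for k = 0..degree, shape (len(z), degree + 1)."""
    z = np.asarray(z, dtype=float)[:, None]
    k = np.arange(degree + 1)[None, :]
    binom = np.array([comb(degree, j) for j in range(degree + 1)], dtype=float)
    return binom * z**k * (1.0 - z) ** (degree - k)


def increasing_coefficients(theta):
    """beta_0 = theta_0, beta_k = beta_{k-1} + softplus(theta_k): strictly increasing."""
    increments = np.logaddexp(0.0, theta[1:])  # softplus
    return np.concatenate([theta[:1], theta[0] + np.cumsum(increments)])


def bernstein_transform(z, beta):
    """h(z) = sum_k beta_k B_{k,M}(z) for z in [0, 1]."""
    return bernstein_basis(z, len(beta) - 1) @ beta


def bernstein_log_derivative(z, beta):
    """log h'(z), using h'(z) = M * sum_k (beta_{k+1} - beta_k) B_{k,M-1}(z)."""
    M = len(beta) - 1
    return np.log(M * bernstein_basis(z, M - 1) @ np.diff(beta))


theta = np.array([-2.0, 0.3, -1.0, 1.5, 0.0])  # unconstrained parameters (illustrative)
beta = increasing_coefficients(theta)
z = np.linspace(0.0, 1.0, 101)
h = bernstein_transform(z, beta)
print(np.all(np.diff(h) > 0))  # True: h is monotone, hence invertible on [0, 1]
```

For the independent multivariate case, one such transformation can be applied per dimension, and the per-dimension log-derivatives are summed to obtain the log-determinant.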

Independent Multivariate

[Figure: Bernstein Flow]

Multiplicative Normalizing Flows

HINT: Hierarchical Invertible Neural Transport

Reference
