Skip to content

Commit

Permalink
Use multipledispatch for kl_registry (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Dec 13, 2021
1 parent 7c8dd4e commit 545db3f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 97 deletions.
4 changes: 3 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ ipython
jax>=0.2.11
jaxlib>=0.1.62
jaxns>=0.0.7
optax>=0.0.6
multipledispatch
nbsphinx==0.8.6
numpy
optax>=0.0.6
readthedocs-sphinx-search==0.1.0
sphinx==4.0.3
sphinx-gallery
Expand Down
119 changes: 23 additions & 96 deletions numpyro/distributions/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from functools import total_ordering
import warnings
from multipledispatch import dispatch

from jax import lax
import jax.numpy as jnp
Expand All @@ -42,124 +41,52 @@
)
from numpyro.distributions.util import scale_and_mask, sum_rightmost

_KL_REGISTRY = (
{}
) # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = (
{}
) # Memoized version mapping many specific (type, type) pairs to functions.


def register_kl(type_p, type_q):
if not isinstance(type_p, type) and issubclass(type_p, Distribution):
raise TypeError(
"Expected type_p to be a Distribution subclass but got {}".format(type_p)
)
if not isinstance(type_q, type) and issubclass(type_q, Distribution):
raise TypeError(
"Expected type_q to be a Distribution subclass but got {}".format(type_q)
)

def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
_KL_MEMOIZE.clear() # reset since lookup order may have changed
return fun

return decorator


@total_ordering
class _Match(object):
__slots__ = ["types"]

def __init__(self, *types):
self.types = types

def __eq__(self, other):
return self.types == other.types

def __le__(self, other):
for x, y in zip(self.types, other.types):
if not issubclass(x, y):
return False
if x is not y:
break
return True


def _dispatch_kl(type_p, type_q):
"""
Find the most specific approximate match, assuming single inheritance.
"""
matches = [
(super_p, super_q)
for super_p, super_q in _KL_REGISTRY
if issubclass(type_p, super_p) and issubclass(type_q, super_q)
]
if not matches:
return NotImplemented
# Check that the left- and right- lexicographic orders agree.
left_p, left_q = min(_Match(*m) for m in matches).types
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
left_fun = _KL_REGISTRY[left_p, left_q]
right_fun = _KL_REGISTRY[right_p, right_q]
if left_fun is not right_fun:
warnings.warn(
"Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
),
RuntimeWarning,
)
return left_fun


def kl_divergence(p, q):
r"""
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
"""
try:
fun = _KL_MEMOIZE[type(p), type(q)]
except KeyError:
fun = _dispatch_kl(type(p), type(q))
_KL_MEMOIZE[type(p), type(q)] = fun
if fun is NotImplemented:
raise NotImplementedError
return fun(p, q)
raise NotImplementedError


################################################################################
# KL Divergence Implementations
################################################################################


@register_kl(Distribution, ExpandedDistribution)
def _kl_dist_expanded(p, q):
@dispatch(Distribution, ExpandedDistribution)
def kl_divergence(p, q):
kl = kl_divergence(p, q.base_dist)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(ExpandedDistribution, Distribution)
def _kl_expanded(p, q):
@dispatch(ExpandedDistribution, Distribution)
def kl_divergence(p, q):
kl = kl_divergence(p.base_dist, q)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(ExpandedDistribution, ExpandedDistribution)
def _kl_expanded_expanded(p, q):
@dispatch(ExpandedDistribution, ExpandedDistribution)
def kl_divergence(p, q):
kl = kl_divergence(p.base_dist, q.base_dist)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(Delta, Distribution)
def _kl_delta(p, q):
@dispatch(Delta, Distribution)
def kl_divergence(p, q):
return -q.log_prob(p.v)


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
@dispatch(Delta, ExpandedDistribution)
def kl_divergence(p, q):
return -q.log_prob(p.v)


@dispatch(Independent, Independent)
def kl_divergence(p, q):
shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
p_ndims = p.reinterpreted_batch_ndims - shared_ndims
q_ndims = q.reinterpreted_batch_ndims - shared_ndims
Expand All @@ -171,8 +98,8 @@ def _kl_independent_independent(p, q):
return kl


@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
@dispatch(MaskedDistribution, MaskedDistribution)
def kl_divergence(p, q):
if p._mask is False or q._mask is False:
mask = False
elif p._mask is True:
Expand All @@ -192,15 +119,15 @@ def _kl_masked_masked(p, q):
return scale_and_mask(kl, mask=mask)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
@dispatch(Normal, Normal)
def kl_divergence(p, q):
var_ratio = jnp.square(p.scale / q.scale)
t1 = jnp.square((p.loc - q.loc) / q.scale)
return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
@dispatch(Dirichlet, Dirichlet)
def kl_divergence(p, q):
# From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
sum_p_concentration = p.concentration.sum(-1)
sum_q_concentration = q.concentration.sum(-1)
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
ignore = W503,E203
per-file-ignores =
numpyro/distributions/kl.py:F811

[isort]
profile = black
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
install_requires=[
f"jax{_jax_version_constraints}",
f"jaxlib{_jaxlib_version_constraints}",
"multipledispatch",
"numpy",
"tqdm",
],
extras_require={
Expand Down

0 comments on commit 545db3f

Please sign in to comment.