Use multipledispatch for kl_registry #1252

Merged: 6 commits, Dec 13, 2021
Changes from all commits
docs/requirements.txt: 4 changes (3 additions & 1 deletion)
@@ -5,8 +5,10 @@ ipython
jax>=0.2.11
jaxlib>=0.1.62
jaxns>=0.0.7
optax>=0.0.6
multipledispatch
nbsphinx==0.8.6
numpy
optax>=0.0.6
readthedocs-sphinx-search==0.1.0
sphinx==4.0.3
sphinx-gallery
numpyro/distributions/kl.py: 119 changes (23 additions & 96 deletions)
@@ -25,8 +25,7 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from functools import total_ordering
import warnings
from multipledispatch import dispatch

from jax import lax
import jax.numpy as jnp
@@ -42,124 +41,52 @@
)
from numpyro.distributions.util import scale_and_mask, sum_rightmost

_KL_REGISTRY = (
{}
) # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = (
{}
) # Memoized version mapping many specific (type, type) pairs to functions.


def register_kl(type_p, type_q):
if not isinstance(type_p, type) and issubclass(type_p, Distribution):
raise TypeError(
"Expected type_p to be a Distribution subclass but got {}".format(type_p)
)
if not isinstance(type_q, type) and issubclass(type_q, Distribution):
raise TypeError(
"Expected type_q to be a Distribution subclass but got {}".format(type_q)
)

def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
_KL_MEMOIZE.clear() # reset since lookup order may have changed
return fun

return decorator


@total_ordering
class _Match(object):
__slots__ = ["types"]

def __init__(self, *types):
self.types = types

def __eq__(self, other):
return self.types == other.types

def __le__(self, other):
for x, y in zip(self.types, other.types):
if not issubclass(x, y):
return False
if x is not y:
break
return True


def _dispatch_kl(type_p, type_q):
"""
Find the most specific approximate match, assuming single inheritance.
"""
matches = [
(super_p, super_q)
for super_p, super_q in _KL_REGISTRY
if issubclass(type_p, super_p) and issubclass(type_q, super_q)
]
if not matches:
return NotImplemented
# Check that the left- and right- lexicographic orders agree.
left_p, left_q = min(_Match(*m) for m in matches).types
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
left_fun = _KL_REGISTRY[left_p, left_q]
right_fun = _KL_REGISTRY[right_p, right_q]
if left_fun is not right_fun:
warnings.warn(
"Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
),
RuntimeWarning,
)
return left_fun


def kl_divergence(p, q):
r"""
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
"""
try:
fun = _KL_MEMOIZE[type(p), type(q)]
except KeyError:
fun = _dispatch_kl(type(p), type(q))
_KL_MEMOIZE[type(p), type(q)] = fun
if fun is NotImplemented:
raise NotImplementedError
return fun(p, q)
raise NotImplementedError


################################################################################
# KL Divergence Implementations
################################################################################


@register_kl(Distribution, ExpandedDistribution)
def _kl_dist_expanded(p, q):
@dispatch(Distribution, ExpandedDistribution)
def kl_divergence(p, q):
kl = kl_divergence(p, q.base_dist)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(ExpandedDistribution, Distribution)
def _kl_expanded(p, q):
@dispatch(ExpandedDistribution, Distribution)
def kl_divergence(p, q):
kl = kl_divergence(p.base_dist, q)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(ExpandedDistribution, ExpandedDistribution)
def _kl_expanded_expanded(p, q):
@dispatch(ExpandedDistribution, ExpandedDistribution)
def kl_divergence(p, q):
kl = kl_divergence(p.base_dist, q.base_dist)
shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
return jnp.broadcast_to(kl, shape)


@register_kl(Delta, Distribution)
def _kl_delta(p, q):
@dispatch(Delta, Distribution)
def kl_divergence(p, q):
return -q.log_prob(p.v)


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
@dispatch(Delta, ExpandedDistribution)
def kl_divergence(p, q):
return -q.log_prob(p.v)


@dispatch(Independent, Independent)
def kl_divergence(p, q):
shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
p_ndims = p.reinterpreted_batch_ndims - shared_ndims
q_ndims = q.reinterpreted_batch_ndims - shared_ndims
@@ -171,8 +98,8 @@ def _kl_independent_independent(p, q):
return kl


@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
@dispatch(MaskedDistribution, MaskedDistribution)
def kl_divergence(p, q):
if p._mask is False or q._mask is False:
mask = False
elif p._mask is True:
@@ -192,15 +119,15 @@ def _kl_masked_masked(p, q):
return scale_and_mask(kl, mask=mask)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
@dispatch(Normal, Normal)
def kl_divergence(p, q):
var_ratio = jnp.square(p.scale / q.scale)
t1 = jnp.square((p.loc - q.loc) / q.scale)
return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
@dispatch(Dirichlet, Dirichlet)
def kl_divergence(p, q):
# From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
sum_p_concentration = p.concentration.sum(-1)
sum_q_concentration = q.concentration.sum(-1)
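As a quick sanity check of the new registry, the sketch below calls the dispatch-based `kl_divergence` on two `Normal` distributions and compares the result with the closed form used in the diff. It is a minimal illustration, assuming a numpyro build that includes this change; the specific parameter values are made up.

```python
# Minimal sketch: exercising the multipledispatch-based kl_divergence.
# Assumes a numpyro version that includes this PR, plus jax/jaxlib installed.
import jax.numpy as jnp

import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

p = dist.Normal(loc=0.0, scale=1.0)
q = dist.Normal(loc=1.0, scale=2.0)

# multipledispatch selects the (Normal, Normal) implementation registered above.
kl = kl_divergence(p, q)

# Cross-check against the closed form used in that implementation.
var_ratio = jnp.square(p.scale / q.scale)
t1 = jnp.square((p.loc - q.loc) / q.scale)
expected = 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
assert jnp.allclose(kl, expected)
```

Downstream code registers new `(type_p, type_q)` pairs with the same `@dispatch(TypeP, TypeQ)` pattern shown above, which replaces the old `register_kl` decorator.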
setup.cfg: 2 changes (2 additions & 0 deletions)
@@ -2,6 +2,8 @@
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
ignore = W503,E203
per-file-ignores =
numpyro/distributions/kl.py:F811

[isort]
profile = black
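The new `per-file-ignores` entry is needed because flake8/pyflakes reports F811 (redefinition of an unused name) whenever a module defines `kl_divergence` more than once, which is exactly what the multipledispatch style above does. A standalone sketch of the pattern that would otherwise trip the check (the integer/float types here are arbitrary placeholders, not numpyro code):

```python
# Standalone illustration: without the per-file ignore, flake8 reports
# F811 ("redefinition of unused 'kl_divergence'") on the second definition,
# even though multipledispatch keeps both and routes calls by argument type.
from multipledispatch import dispatch


@dispatch(int, int)
def kl_divergence(p, q):
    return "int pair"


@dispatch(float, float)
def kl_divergence(p, q):  # flake8 would flag this line as F811
    return "float pair"


print(kl_divergence(1, 2))      # -> int pair
print(kl_divergence(1.0, 2.0))  # -> float pair
```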
2 changes: 2 additions & 0 deletions setup.py
@@ -35,6 +35,8 @@
install_requires=[
f"jax{_jax_version_constraints}",
f"jaxlib{_jaxlib_version_constraints}",
"multipledispatch",
"numpy",
"tqdm",
],
extras_require={