Commit

Add Beta, Dirichlet, and Binomial distributions (#120)
neerajprad authored and fritzo committed Apr 8, 2019
1 parent f74df0a commit 48fa3f7
Showing 4 changed files with 323 additions and 8 deletions.
148 changes: 141 additions & 7 deletions funsor/distributions.py
@@ -5,6 +5,7 @@

import pyro.distributions as dist
import torch
from pyro.distributions.util import broadcast_shape
from six import add_metaclass

from pyro.distributions.util import broadcast_shape
@@ -16,7 +17,7 @@
from funsor.gaussian import Gaussian
from funsor.interpreter import interpretation
from funsor.terms import Funsor, FunsorMeta, Number, Subs, Variable, eager, lazy, to_funsor
from funsor.torch import Tensor, align_tensors, materialize
from funsor.torch import Tensor, align_tensors, materialize, torch_stack


def numbers_to_tensors(*args):
@@ -109,6 +110,7 @@ class Bernoulli(Distribution):
@staticmethod
def _fill_defaults(probs, value='value'):
probs = to_funsor(probs)
assert probs.dtype == "real"
value = to_funsor(value, reals())
return probs, value

@@ -121,12 +123,68 @@ def eager_categorical(probs, value):
return Bernoulli.eager_log_prob(probs=probs, value=value)


class Beta(Distribution):
dist_class = dist.Beta

@staticmethod
def _fill_defaults(concentration1, concentration0, value='value'):
concentration1 = to_funsor(concentration1, reals())
concentration0 = to_funsor(concentration0, reals())
value = to_funsor(value, reals())
return concentration1, concentration0, value

def __init__(self, concentration1, concentration0, value=None):
super(Beta, self).__init__(concentration1, concentration0, value)


@eager.register(Beta, Tensor, Tensor, Tensor)
def eager_beta(concentration1, concentration0, value):
return Beta.eager_log_prob(concentration1=concentration1,
concentration0=concentration0,
value=value)


@eager.register(Beta, Funsor, Funsor, Funsor)
def eager_beta(concentration1, concentration0, value):
concentration = torch_stack((concentration0, concentration1))
value = torch_stack((1 - value, value))
return Dirichlet(concentration, value=value)
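
The identity behind this pattern, that Beta(c1, c0) evaluated at v has the same
log-density as Dirichlet([c0, c1]) evaluated at [1 - v, v], can be sanity-checked
against torch.distributions (hypothetical values, not part of this diff):

    import torch
    import torch.distributions as td

    c1, c0, v = torch.tensor(2.0), torch.tensor(3.0), torch.tensor(0.25)
    # Beta(c1, c0) at v ...
    lhs = td.Beta(c1, c0).log_prob(v)
    # ... equals Dirichlet([c0, c1]) at [1 - v, v], matching the stacking order above.
    rhs = td.Dirichlet(torch.stack([c0, c1])).log_prob(torch.stack([1 - v, v]))
    assert torch.allclose(lhs, rhs)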


class Binomial(Distribution):
dist_class = dist.Binomial

@staticmethod
def _fill_defaults(total_count, probs, value='value'):
total_count = to_funsor(total_count, reals())
probs = to_funsor(probs)
assert probs.dtype == "real"
value = to_funsor(value, reals())
return total_count, probs, value

def __init__(self, total_count, probs, value=None):
super(Binomial, self).__init__(total_count, probs, value)


@eager.register(Binomial, Tensor, Tensor, Tensor)
def eager_binomial(total_count, probs, value):
return Binomial.eager_log_prob(total_count=total_count, probs=probs, value=value)


@eager.register(Binomial, Funsor, Funsor, Funsor)
def eager_binomial(total_count, probs, value):
probs = torch_stack((1 - probs, probs))
value = torch_stack((total_count - value, value))
return Multinomial(total_count, probs, value=value)
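
Likewise, the reduction above relies on Binomial(n, p) at k having the same
log-density as Multinomial(n, [1 - p, p]) at [n - k, k]; a quick check with
torch.distributions (hypothetical values, not part of this diff):

    import torch
    import torch.distributions as td

    n, p, k = 10, torch.tensor(0.3), torch.tensor(4.0)
    lhs = td.Binomial(n, p).log_prob(k)
    # Failure/success counts stacked in the same order as the probs above.
    rhs = td.Multinomial(n, torch.stack([1 - p, p])).log_prob(torch.stack([n - k, k]))
    assert torch.allclose(lhs, rhs)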


class Categorical(Distribution):
dist_class = dist.Categorical

@staticmethod
def _fill_defaults(probs, value='value'):
probs = to_funsor(probs)
assert probs.dtype == "real"
value = to_funsor(value, bint(probs.output.shape[0]))
return probs, value

Expand Down Expand Up @@ -156,7 +214,7 @@ class Delta(Distribution):
@staticmethod
def _fill_defaults(v, log_density=0, value='value'):
v = to_funsor(v)
log_density = to_funsor(log_density)
log_density = to_funsor(log_density, reals())
value = to_funsor(value, v.output)
return v, log_density, value

@@ -188,6 +246,50 @@ def eager_delta(v, log_density, value):
return funsor.delta.Delta(v.name, value, log_density)


class Dirichlet(Distribution):
dist_class = dist.Dirichlet

@staticmethod
def _fill_defaults(concentration, value='value'):
concentration = to_funsor(concentration)
assert concentration.dtype == "real"
assert len(concentration.output.shape) == 1
dim = concentration.output.shape[0]
value = to_funsor(value, reals(dim))
return concentration, value

def __init__(self, concentration, value='value'):
super(Dirichlet, self).__init__(concentration, value)


@eager.register(Dirichlet, Tensor, Tensor)
def eager_dirichlet(concentration, value):
return Dirichlet.eager_log_prob(concentration=concentration, value=value)
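
A hypothetical usage sketch (not part of this diff), assuming the eager
interpretation and the Tensor wrapper behave as for the other distributions in
this module: with all-Tensor arguments the pattern above fires and the result is
a funsor Tensor holding the Dirichlet log-density.

    import torch
    import funsor.distributions as fdist
    from funsor.torch import Tensor

    concentration = Tensor(torch.tensor([1.0, 2.0, 3.0]))
    value = Tensor(torch.tensor([0.2, 0.3, 0.5]))
    # Equivalent to torch.distributions.Dirichlet(concentration).log_prob(value),
    # wrapped as a funsor Tensor.
    log_prob = fdist.Dirichlet(concentration, value=value)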


class DirichletMultinomial(Distribution):
dist_class = dist.DirichletMultinomial

@staticmethod
def _fill_defaults(concentration, total_count=1, value='value'):
concentration = to_funsor(concentration)
assert concentration.dtype == "real"
assert len(concentration.output.shape) == 1
total_count = to_funsor(total_count, reals())
dim = concentration.output.shape[0]
value = to_funsor(value, reals(dim)) # Should this be bint(total_count)?
return concentration, total_count, value

def __init__(self, concentration, total_count, value='value'):
super(DirichletMultinomial, self).__init__(concentration, total_count, value)


@eager.register(DirichletMultinomial, Tensor, Tensor, Tensor)
def eager_dirichlet_multinomial(concentration, total_count, value):
return DirichletMultinomial.eager_log_prob(
concentration=concentration, total_count=total_count, value=value)


def LogNormal(loc, scale, value='value'):
loc, scale, y = Normal._fill_defaults(loc, scale, value)
t = ops.exp
@@ -196,16 +298,42 @@ def LogNormal(loc, scale, value='value'):
return Normal(loc, scale, x) - log_abs_det_jacobian


class Multinomial(Distribution):
dist_class = dist.Multinomial

@staticmethod
def _fill_defaults(total_count, probs, value='value'):
total_count = to_funsor(total_count, reals())
probs = to_funsor(probs)
assert probs.dtype == "real"
assert len(probs.output.shape) == 1
value = to_funsor(value, probs.output)
return total_count, probs, value

def __init__(self, total_count, probs, value=None):
super(Multinomial, self).__init__(total_count, probs, value)


@eager.register(Multinomial, Tensor, Tensor, Tensor)
def eager_multinomial(total_count, probs, value):
# Multinomial.log_prob() supports inhomogeneous total_count only by
# avoiding passing total_count to the constructor.
inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
probs = Tensor(probs.expand(shape), inputs)
value = Tensor(value.expand(shape), inputs)
total_count = Number(total_count.max().item()) # Used by distributions validation code.
return Multinomial.eager_log_prob(total_count=total_count, probs=probs, value=value)
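
The comment above leans on the fact that torch's Multinomial.log_prob() reads
only `value` and `probs`; the constructor's total_count matters for sampling and
validation but not for the log-density. A small sketch with hypothetical values
(not part of this diff):

    import torch
    import torch.distributions as td

    probs = torch.tensor([0.2, 0.8])
    value = torch.tensor([[1., 2.], [3., 4.]])          # rows sum to 3 and 7
    d = td.Multinomial(7, probs, validate_args=False)   # 7 = max row total
    expected = torch.stack([td.Multinomial(3, probs).log_prob(value[0]),
                            td.Multinomial(7, probs).log_prob(value[1])])
    assert torch.allclose(d.log_prob(value), expected)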


class Normal(Distribution):
dist_class = dist.Normal

@staticmethod
def _fill_defaults(loc, scale, value='value'):
loc = to_funsor(loc)
scale = to_funsor(scale)
assert loc.output == reals()
assert scale.output == reals()
value = to_funsor(value, loc.output)
loc = to_funsor(loc, reals())
scale = to_funsor(scale, reals())
value = to_funsor(value, reals())
return loc, scale, value

def __init__(self, loc, scale, value='value'):
@@ -313,10 +441,16 @@ def eager_mvn(loc, scale_tril, value):


__all__ = [
'Bernoulli',
'Beta',
'Binomial',
'Categorical',
'Delta',
'Dirichlet',
'DirichletMultinomial',
'Distribution',
'LogNormal',
'Multinomial',
'MultivariateNormal',
'Normal',
]
25 changes: 25 additions & 0 deletions funsor/torch.py
@@ -667,6 +667,31 @@ def torch_tensordot(x, y, dims):
return Function(fn, output, (x, y))


def _torch_stack(dim, *parts):
return torch.stack(parts, dim=dim)


def torch_stack(parts, dim=0):
"""
Wrapper around :func:`torch.stack` to operate on real-valued Funsors.
Note this operates only on the ``output`` tensor. To stack funsors in a
new named dim, instead use :class:`~funsor.terms.Stack`.
"""
assert isinstance(dim, int)
assert isinstance(parts, tuple)
assert len(set(x.output for x in parts)) == 1
shape = parts[0].output.shape
if dim >= 0:
dim = dim - len(shape) - 1
assert dim < 0
split = dim + len(shape) + 1
shape = shape[:split] + (len(parts),) + shape[split:]
output = Domain(shape, parts[0].dtype)
fn = functools.partial(_torch_stack, dim)
return Function(fn, output, parts)
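
As an illustration of the dim arithmetic above (hypothetical shapes, not part of
this diff): stacking on output dim 0 of parts whose output shape is (3,) maps to
torch.stack(..., dim=-2) on the raw data, leaving any leading input (batch) dims
untouched:

    import torch

    parts = [torch.randn(5, 3), torch.randn(5, 3)]   # 5 plays the role of an input (batch) dim
    stacked = torch.stack(parts, dim=0 - 1 - 1)      # dim = dim - len(shape) - 1 = -2
    assert stacked.shape == (5, 2, 3)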


################################################################################
# Register Ops
################################################################################
