Add LogNormalNegativeBinomial distribution #3010

Merged · 16 commits · Jan 27, 2022
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -232,6 +232,13 @@ LKJCorrCholesky
:undoc-members:
:show-inheritance:

LogNormalNegativeBinomial
-------------------------
.. autoclass:: pyro.distributions.LogNormalNegativeBinomial
:members:
:undoc-members:
:show-inheritance:

Logistic
--------
.. autoclass:: pyro.distributions.Logistic
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -49,6 +49,7 @@
from pyro.distributions.improper_uniform import ImproperUniform
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.lkj import LKJ, LKJCorrCholesky
from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial
from pyro.distributions.logistic import Logistic, SkewLogistic
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.multivariate_studentt import MultivariateStudentT
@@ -124,6 +125,7 @@
"LKJCorrCholesky",
"LinearHMM",
"Logistic",
"LogNormalNegativeBinomial",
"MaskedDistribution",
"MaskedMixture",
"MixtureOfDiagNormals",
58 changes: 58 additions & 0 deletions pyro/distributions/log_normal_negative_binomial.py
@@ -0,0 +1,58 @@
import numpy as np
from numpy.polynomial.hermite import hermgauss

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

# Direct submodule imports avoid a circular import of the pyro.distributions package.
from pyro.distributions.torch import NegativeBinomial
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


def get_quad_rule(num_quad, prototype_tensor):
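    # `hermgauss` returns nodes x_i and weights w_i for integrals of the form
    #   int exp(-x^2) f(x) dx  ~=  sum_i w_i f(x_i).
    # Rescaling the nodes by sqrt(2) and dividing the weights by sqrt(pi)
    # converts this into an expectation under a standard Normal:
    #   E_{eps ~ N(0,1)}[f(eps)]  ~=  sum_i (w_i / sqrt(pi)) f(sqrt(2) x_i).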
    quad_rule = hermgauss(num_quad)
    quad_points = quad_rule[0] * np.sqrt(2.0)
    log_weights = np.log(quad_rule[1]) - 0.5 * np.log(np.pi)
    return (
        torch.from_numpy(quad_points).type_as(prototype_tensor),
        torch.from_numpy(log_weights).type_as(prototype_tensor),
    )


class LogNormalNegativeBinomial(TorchDistribution):
"""
A three-parameter generalization of the Negative Binomial distribution [1].
It can be understood as a continuous mixture of Negative Binomial distributions
in which we inject Normally-distributed noise into the logits of the Negative
Binomial distribution:

:math:`\rm{LNNB}(\rm{total_count}=\nu, \rm{logits}=\ell, \rm{multiplicative_noise_scale}=sigma) = \int d\epsilon
\mathcal{N}(\epsilon | 0, \sigma) \rm{NB}(\rm{total_count}=\nu, \rm{logits}=\ell + \epsilon)`

References:
[1] "Lognormal and Gamma Mixed Negative Binomial Regression,"
Mingyuan Zhou, Lingbo Li, David Dunson, and Lawrence Carin.

:param total_count: non-negative number of negative Bernoulli trials.
:type total_count: float or torch.Tensor
:param torch.Tensor logits: Event log-odds for probabilities of success for underlying
Negative Binomial distribution.
:param num_quad_points: Number of quadrature points used to compute the (approximate) `log_prob`.
Defaults to 8.
:type num_quad_points: int
"""
    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "logits": constraints.real,
        "multiplicative_noise_scale": constraints.positive,
    }
    support = constraints.nonnegative_integer

    def __init__(self, total_count, logits, multiplicative_noise_scale,
                 num_quad_points=8, validate_args=None):
        # Promote floats to tensors and broadcast the parameters against each other.
        total_count, logits, multiplicative_noise_scale = broadcast_all(
            total_count, logits, multiplicative_noise_scale)
        self.quad_points, self.log_weights = get_quad_rule(num_quad_points, logits)
        # One Negative Binomial component per (scaled) quadrature node.
        quad_logits = logits.unsqueeze(-1) + multiplicative_noise_scale.unsqueeze(-1) * self.quad_points
        self.nb = NegativeBinomial(total_count=total_count.unsqueeze(-1), logits=quad_logits)
        self.multiplicative_noise_scale = multiplicative_noise_scale
        batch_shape = broadcast_shape(multiplicative_noise_scale.shape, self.nb.batch_shape[:-1])
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        nb_log_prob = self.nb.log_prob(value.unsqueeze(-1))
        # Marginalize out the injected noise with a log-space weighted sum
        # over the quadrature points.
        return torch.logsumexp(self.log_weights + nb_log_prob, dim=-1)

    def sample(self, sample_shape=torch.Size()):
        # Sampling is not yet implemented for this distribution.
        raise NotImplementedError
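
To make the quadrature approximation concrete, here is a minimal usage sketch against the new API (the parameter values and Monte Carlo sample size are illustrative): the quadrature-based log_prob should closely match a brute-force Monte Carlo average over the injected logit noise.

import torch
from pyro.distributions import LogNormalNegativeBinomial, NegativeBinomial

total_count = torch.tensor(10.0)
logits = torch.tensor(0.5)
noise_scale = torch.tensor(0.8)
value = torch.tensor(7.0)

d = LogNormalNegativeBinomial(total_count, logits, noise_scale, num_quad_points=16)
print(d.log_prob(value))  # quadrature-based approximation

# Brute-force Monte Carlo estimate of the same logit-noise integral.
num_samples = 200000
eps = noise_scale * torch.randn(num_samples)
mc_log_prob = torch.logsumexp(
    NegativeBinomial(total_count, logits=logits + eps).log_prob(value), dim=0
) - torch.log(torch.tensor(float(num_samples)))
print(mc_log_prob)  # should be close to d.log_prob(value)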