Add ZeroSumNormal distribution #1751
Conversation
Still have some work to do on handling the support_shape / event_shape. Calling out that the logic behind having an event_shape=1 by default helps tests pass and works with the sampling shapes (and is similar to the conventions for MultivariateNormal), but it makes model fitting very slow. Currently trying to figure this out. I think the support shape needs to default to None and use the default event_shape=() when a plate is involved.
This seems to be working. Need to test some more edge cases and hopefully simplify the logic in the `__init__` function.
I think this is almost there. A question I still have is whether I'm handling the `support_shape` argument correctly in conjunction with the `event_shape` and `batch_shape`. I was under the impression that `support_shape` is most similar to an event shape: if I have `support_shape=(20,)` and sample from it, I end up with an array of `shape=(20,)`, and if I then pass those samples through the log prob function I get one value (`shape=()`). But I think this gets a little weird in conjunction with a plate:

```python
with numpyro.plate("n_categories", 20):
    a_category = numpyro.sample("a_category", dist.ZeroSumNormal(1, n_zerosum_axes=1))
```

Here there's a zero-sum constraint placed over the 20 different category parameters, so I feel like this is more of a batch_shape thing?
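For reference (not part of the PR), here is a quick way to see the convention being referred to, using MultivariateNormal: samples carry the event shape, while log_prob collapses it to one value per batch element.

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.MultivariateNormal(jnp.zeros(20), jnp.eye(20))
x = d.sample(jax.random.PRNGKey(0))
print(x.shape)              # (20,) -- the event shape
print(d.log_prob(x).shape)  # ()    -- one value per batch element
```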
Thanks @kylejcaron! This is the first time I've come across this distribution. Could you include some applications of the distribution in the description? A couple of comments:

- It seems to me that `n_zerosum_axes` is the number of event dimensions of this distribution. I think it is better to use only two parameters, `scale` and `event_dim`, and add the interpretation for them in the description (see the sketch after this list):
  - `batch_shape = scale.shape[:-event_dim]`
  - `event_shape = scale.shape[-event_dim:]`
  - `n_zerosum_axes := event_dim`
  - `support_shape := event_shape`
- You will need to define a Constraint for this; `constraints.real` is not the domain of this distribution.
- If you plan to use this distribution as a latent variable, you will need to define a transform for the domain.
- Looking at the PyMC doc, it is not clear to me how we can sample from this distribution. Could you add a better reference or describe it in the description?
- Do you know how to check the correctness of `log_prob`? We have `gof_test` for other distributions, but it seems you disabled it for this distribution.
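A tiny sketch of the shape bookkeeping proposed in the first bullet above; the helper name `zsn_shapes` is purely illustrative and not part of the PR:

```python
import jax.numpy as jnp

def zsn_shapes(scale, event_dim):
    """Shape bookkeeping only: treat the trailing `event_dim` axes of `scale`
    as the correlated (event) axes and the rest as batch axes."""
    scale = jnp.asarray(scale)
    batch_shape = scale.shape[:-event_dim] if event_dim else scale.shape
    event_shape = scale.shape[-event_dim:] if event_dim else ()
    return batch_shape, event_shape

# e.g. scale of shape (3, 20) with event_dim=1 -> batch (3,), event (20,)
print(zsn_shapes(jnp.ones((3, 20)), event_dim=1))  # ((3,), (20,))
```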
The issue happens because you use |
Hmm, the formula in the description looks strange to me. If ZSN(scale) = Normal(...), then this is just a reparameterized version of the Normal distribution. It is not clear to me how the zero-sum constraint is enforced there.
Great to see a zero-sum distribution being implemented! Implementing it using a transform would be worthwhile. We're just writing up a manuscript for a sum-to-zero transformation/distribution.
@fehiepsi You're right, that's the likelihood of a normal distribution that's scaled to account for the zero-sum axes being transformed to shape n-1. I realized this is missing the zero-sum transform that @tillahoffmann is mentioning, where the original input is transformed to a vector of size n-1 along each zero-sum axis; I'm working on implementing that now. Here's the equation from the Olga and Barber paper (eq. 9, page 12).

A question for both of you: are there any resources that discuss numpyro transforms a little further? I'm having trouble understanding how I would implement this. Creating a ZeroSumTransform class seems simple, but how to call it is where I'm struggling. Most of the transforms I'm seeing take a base distribution and apply a transform to it, and I can't tell whether that's the right approach here. An example of what I'm referring to is the log-normal distribution, which is a transform on top of a normal distribution:

```python
class LogNormal(TransformedDistribution):
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.positive
    reparametrized_params = ["loc", "scale"]

    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
        base_dist = Normal(loc, scale)
        self.loc, self.scale = base_dist.loc, base_dist.scale
        super(LogNormal, self).__init__(
            base_dist, ExpTransform(), validate_args=validate_args
        )
```

I think this pattern wouldn't work well with the current sampling method, and potentially with the log prob as well?
@fehiepsi Yes, I can add some applications in the description! In the meantime, the typical use case is as an alternative to dummy encoding, where there can be a parameter for every single category of interest while still having an intercept parameter, without running into identifiability issues.

Technically I think n_zerosum_axes can be different from the event_dim. For example, say we have a 2d categorical variable of shape=(2, 50); we may only want to enforce a zero-sum constraint over the last dim, i.e.

```python
with numpyro.plate("n_conditions", 2):
    with numpyro.plate("n_categories", 20, dim=-1):
        a_category = numpyro.sample("a_category", dist.ZeroSumNormal(1, n_zerosum_axes=1))
```

Curious what you think given that example?
I think you can use LKJCholesky as an example. It has a non-trivial support which requires a sort of lower-Cholesky transform. Re event dim: this denotes correlated dimensions, which is the last dimension of your ZeroSum example. Plates denote independent dimensions. So you need to remove the category plate if you think that those categories are correlated (via the zero-sum enforcement).
edit: note I'm assuming you mean for ...

To make sure I have this right: let's say we're modeling the effects of the 50 states in the USA and some continuous regressor X. The traditional way to do this simply may be to have a separate intercept parameter for every state and no global intercept.

example 1:

```python
def model_example1(X, state, y=None):
    with numpyro.plate("states", 50):
        alpha_state = numpyro.sample("alpha_state", dist.Normal(0, 1))

    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(alpha_state[state] + beta * X, sigma), obs=y)
```

If we want to have a global intercept parameter for interpretability, the ZeroSumNormal could come in as follows.

example 2a:

```python
def model_example2a(X, state, y=None):
    with numpyro.plate("states", 50):
        alpha_state = numpyro.sample("alpha_state", dist.ZeroSumNormal(scale=1, n_zerosum_axes=1))

    global_intercept = numpyro.sample("global_intercept", dist.Normal(0, 1))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(global_intercept + alpha_state[state] + beta * X, sigma), obs=y)
```

I didn't quite think about how the zero-sum constraint creates correlated dimensions. Would that mean the better pattern to use instead of example 2a is the following?

example 2b:

```python
def model_example2b(X, state, y=None):
    alpha_state = numpyro.sample(
        "alpha_state", dist.ZeroSumNormal(scale=1, n_zerosum_axes=1, support_shape=(50,))
    )
    global_intercept = numpyro.sample("global_intercept", dist.Normal(0, 1))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(global_intercept + alpha_state[state] + beta * X, sigma), obs=y)
```

Or are you saying that we can keep the structure of example 2a (and just rename ...)?
Nice to see this implemented in numpyro as well :-)
I didn't follow the discussion fully, but the current implementation in PyMC already scales linearly in n. We use the fact that the subspace defined by the zero-sum property is a linear subspace, so we can just map from a vector space that is one dimension smaller using a Householder transformation.
yes, that's my interpretation from the formula. The sum of
yeah, I think we can just follow the way samples are generated: pad the unconstrained vector by zero across all correlated dimensions, then subtract the mean, etc. Probably the Householder transform does a better job. Could you provide some pointers, @aseyboldt?
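A minimal sketch of that construction (pad with a zero along the constrained axis, then subtract the mean), just to make the shapes concrete; the function name is hypothetical:

```python
import jax.numpy as jnp

def naive_zero_sum(x, axis=-1):
    """Map an unconstrained (n-1)-vector to an n-vector that sums to zero
    by appending a zero and centering along `axis`."""
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, 1)
    padded = jnp.pad(x, pad_width)                      # append a zero along `axis`
    return padded - padded.mean(axis, keepdims=True)    # enforce the zero sum

x = jnp.array([0.3, -1.2, 0.5])
print(naive_zero_sum(x).sum())  # ~0.0
```

Note that with an iid normal base this simple map does not reproduce the ZeroSumNormal covariance exactly (the resulting coordinates are not exchangeable), which is presumably why the Householder transform mentioned here does a better job.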
You can find the code of the transformation here: https://github.com/pymc-devs/pymc/blob/main/pymc/distributions/transforms.py#L266

The covariance of the standard ZSN has one zero eigenvalue, with the all-ones vector as its eigenvector. We apply a Householder reflection that maps this eigenvector onto the first basis vector. In this transformed space the ZSN is just a diagonal normal with the first variance zero, so we can get rid of the first entry of that array entirely and get a map to an (n-1)-dimensional space. (A small sketch of this idea follows after the model examples below.)

If your design for distributions allows you to define the density on the transformed space directly, that makes things really easy here, since on the transformed space we have an iid normal distribution. If not (as is unfortunately the case in pymc), you have to trick a bit to still write down a density in the untransformed space that works out (see https://github.com/pymc-devs/pymc/blob/main/pymc/distributions/multivariate.py#L2841).

Hope that helps, if not feel free to ask.

To elaborate on the use case of the ZSN a bit, I think it would mostly be these two. For the hierarchical case, take a model like this:

```python
with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept = pm.Normal("intercept", sigma=intercept_sd)
    group_effect = pm.Normal("group_effect", sigma=group_sd, dims="group")
    mu = intercept + group_effect[group_idx]

    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)
```

Here the intercept and the mean of `group_effect` are only weakly identified, which tends to produce strong posterior correlations and slow sampling. We can fix this by splitting the group effect into its mean and zero-sum deflections around that mean:

```python
with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept = pm.Normal("intercept", sigma=intercept_sd)
    group_mean = pm.Normal("group_mean", sigma=pt.sqrt(group_sd**2 / n_groups))
    group_deflect = pm.ZeroSumNormal("group_deflect", sigma=group_sd, dims="group")
    group_effect = group_mean + group_deflect

    mu = intercept + group_mean + group_deflect[group_idx]
    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)
```

But now we can merge `intercept` and `group_mean` into a single parameter:

```python
with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept_plus_group_mean = pm.Normal(
        "intercept_plus_group_mean", sigma=pt.sqrt(intercept_sd**2 + group_sd**2 / n_groups)
    )
    group_deflect = pm.ZeroSumNormal("group_deflect", sigma=group_sd, dims="group")

    mu = intercept_plus_group_mean + group_deflect[group_idx]
    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)
```

This also works for more complicated hierarchical models; in some cases you then need several zero-sum axes. If necessary, draws from the original intercept and group_mean can be reconstructed after sampling.
@aseyboldt thank you for the great response, that's really helpful.
I was under the impression that in pymc the ZeroSumTransform is applied and then the logprob is contributed; is that not actually the case? Where exactly do the forward and backward transforms get applied? (I'm admittedly having trouble following the ...)
```
@@ -1380,6 +1382,92 @@ def __eq__(self, other):
        return jnp.array_equal(self.transition_matrix, other.transition_matrix)


class ZeroSumTransform(Transform):
    """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]
```
@AlexAndorra @aseyboldt @ricardoV94 as I said above, this PR is nearly ready to go; let me know if there's more I can add to properly credit all of you and pymc.
numpyro/distributions/continuous.py (Outdated)

```python
for axis in zero_sum_axes:
    theoretical_var *= 1 - 1 / self.event_shape[axis]

return theoretical_var
```
Similar to the mean, we need to broadcast this to batch_shape + event_shape
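A sketch of the requested fix, written as a standalone helper; the argument names mirror the diff above, but the exact attributes of the final class may differ:

```python
import jax.numpy as jnp

def zsn_variance(scale, batch_shape, event_shape, n_zerosum_axes):
    """Illustrative broadcast of the ZeroSumNormal variance to batch_shape + event_shape."""
    theoretical_var = jnp.square(jnp.asarray(scale))
    for axis in range(-n_zerosum_axes, 0):
        # each constrained axis of size k loses one degree of freedom: var *= (k - 1) / k
        theoretical_var = theoretical_var * (1 - 1 / event_shape[axis])
    return jnp.broadcast_to(theoretical_var, batch_shape + event_shape)

print(zsn_variance(1.0, (), (20,), 1).shape)  # (20,)
```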
got it, updated!
LGTM pending feedback from @tillahoffmann! Thanks for the awesome work, @kylejcaron!
Looks great! A few more little comments.
Few more little ones, sorry this is such a piecemeal review.
Co-authored-by: Till Hoffmann <[email protected]>
Don't worry about it, I really appreciate all of the help; I still have a lot to learn. Thank you! Unless you can find any other issues, that discussion on the scale param might be the last unresolved piece.
This looks great! Thank you, @kylejcaron!
Also thanks for a detailed review, @tillahoffmann!
Thank you everyone for all of the help, it's been much appreciated! @fehiepsi I had to re-request review; the doctest for distributions/continuous.py was failing and didn't show up when I ran ... I was able to test it manually with ...
* added zerosumnormal and tests
* added edge case handling for support shape
* removed commented out functions
* added zerosumnormal to docs
* fixed zerosumnormal support shape default
* Added v1 of docstrings for zerosumnormal
* updated zsn docstring
* improved init shape handling for zerosumnormal
* improved docstrings
* added ZeroSumTransform
* made n_zerosum_axes an attribute for the zerosumtransform
* removed commented out lines
* added zerosumtransform class
* switched zsn from ParameterFreeTransform to Transform
* changed ZeroSumNormal to transformed distribution
* changed input to tuple for _transform_to_zero_sum
* added forward and inverse shape to transform, fixed zero_sum constraint handling
* fixed failing zsn tests
* added docstring, removed whitespace, fixed missing import
* fixed allclose to be assert allclose
* linted and formatted
* added sample code to docstring for zsn
* updated docstring
* removed list from ZeroSum constraint call
* removed unneeded iteration, updated docstring
* updated constraint code
* added ZeroSumTransform to docs
* fixed transform shapes
* added doctest example for zsn
* added constraint test
* added zero_sum constraint to docs
* added type hinting to transforms file
* fixed docs formatting
* moved skip zsn from test_gof earlier
* reversed zerosumtransform
* broadcasted mean and var of zsn
* added stricter zero_sum constraint tol, improved mean and var functions
* fixed _transform_to_zero_sum
* removed shape promote from zsn, changed broadcast to zeros_like
* chose better zsn test cases
* Update zero_sum constraint feasible_like (Co-authored-by: Till Hoffmann <[email protected]>)
* fixed docstring for doctests

---------

Co-authored-by: Till Hoffmann <[email protected]>
This ports the ZeroSumNormal distribution (by @aseyboldt and @AlexAndorra) over from PyMC. Note that it is a special case of a normal distribution that enforces a zero-sum constraint over specified dimensions.
This was a bit ambitious as a first contribution, so I added the WIP tag on this to be safe.
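For readers skimming this thread, a short usage sketch in the spirit of example 2b above, using the argument names discussed in the conversation (the merged API may differ slightly, so treat the exact signature as indicative rather than authoritative):

```python
import numpyro
import numpyro.distributions as dist

def model(X, state, y=None):
    # 50 state effects constrained to sum to zero, keeping the global intercept identifiable
    alpha_state = numpyro.sample(
        "alpha_state", dist.ZeroSumNormal(scale=1.0, n_zerosum_axes=1, support_shape=(50,))
    )
    global_intercept = numpyro.sample("global_intercept", dist.Normal(0, 1))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        numpyro.sample(
            "obs", dist.Normal(global_intercept + alpha_state[state] + beta * X, sigma), obs=y
        )
```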