Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZeroSumNormal distribution #1751

Merged
merged 43 commits into from
Mar 30, 2024
Merged

Add ZeroSumNormal distribution #1751

merged 43 commits into from
Mar 30, 2024

Conversation

kylejcaron
Copy link
Contributor

@kylejcaron kylejcaron commented Feb 29, 2024

This ports the ZeroSumNormal distribution (by @aseyboldt and @AlexAndorra) over from PyMC. Note that it is a special case of a normal distribution that enforces a zerosum constraint over specified dimensions.

$$ \begin{align*} ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\ \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ n = \text{nbr of zero-sum axes} \end{align*} $$

This was a bit ambitious as a first contribution so I added the WIP tag on this to be safe.

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Feb 29, 2024

Still have some work to do on handling the support_shape / event_shape.

Calling out that the logic behind having an event_shape=1 by default helps tests pass and works with the sampling shapes (and is similar to the conventions for the multivariatenormal), but makes model fitting very slow. Currently trying to figure this out.

I think the support shape needs to default to none and use the default event_shape=() when a plate is involved.

@kylejcaron
Copy link
Contributor Author

This seems to be working. Need to test some more edge cases and hopefully simplify the logic in the init function

@kylejcaron
Copy link
Contributor Author

I think this is almost there - a question I still have is if I'm handling the support_shape argument correctly in conjunction with the event_shape and batch_shape.

I was under the impression that the support_shape argument is most similar to an event shape, where if I have support_shape=(20,), and sample from it I end up with an array of shape=(20,), and if I then pass those samples through the log prob function I would get 1 value (shape=() ).

But I think this gets a little weird in conjunction with the plate or .expand method, which adjusts the batch_shape of the distribution, not the event shape. So if typical usage might be the following

with numpyro.plate("n_categories", 20):
     a_category = numpyro.sample("a_category", dist.ZeroSumNormal(1, n_zerosum_axes=1))

and theres a zero sum constraint placed over the 20 different category parameters, I feel like this is more of a batch_shape thing?

@kylejcaron kylejcaron marked this pull request as ready for review February 29, 2024 17:51
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @kylejcaron! First time I know about this distribution. Could you include some applications of the distribution in the description? A couple of comments:

  • It seems to me that n_zerosum_axes is the number of event dimensions of this distribution. I think it is better to use only two parameters scale and event_dim. You can add the interpretation for them in the description.
batch_shape = scale.shape[:-event_dim]
event_shape = scale.shape[-event_dim:]
n_zerosum_axes := event_dim
support_shape := event_shape
  • you will need to define a Constraint for this; constraints.real is not the domain of this distribution
  • if you plan to use this distribution as a latent variable, you will need to define a transform for the domain
  • Looking at the PyMC doc, it is not clear to me how we can sample from this distribution. Could you add a better reference or describe it in the description?
  • Do you know how to check for the correctness of log_prob? We have gof_test for other distributions but it seems that you disable it for this distribution.

@fehiepsi
Copy link
Member

fehiepsi commented Mar 1, 2024

But I think this gets a little weird in conjunction with the plate or .expand method

The issue happens because you use constraints.real as the support of this distribution. You need to define something like LowerCholesky constraint.

@fehiepsi
Copy link
Member

fehiepsi commented Mar 1, 2024

Hmm, the formula in the description looks strange to me. If ZSN(scale) = Normal(...), then this is just a reparameterized version of Normal distribution. It is not clear to me how to enforce zero-sum there.

@tillahoffmann
Copy link
Contributor

Great to see a zero-sum distribution being implemented!

Implementing it using a ParameterFreeTransform could do the trick as described in Ogle and Baker (2020, p. 12). Since each sample of size n is constrained such that it sums to zero, posterior sampling is difficult unless the constrained vector is first transformed to an unconstrained vector of size n - 1 akin to the StickBreakingTransform for the Dirichlet distribution. This does however come at the cost of having to place a full-rank multivariate normal prior on the unconstrained space.

We're just writing up a manuscript for a sum-to-zero transformation/distribution whose log_prob evaluation scales linearly with n. Not quite there yet, but hope to be able to share soon.

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Mar 5, 2024

Hmm, the formula in the description looks strange to me. If ZSN(scale) = Normal(...), then this is just a reparameterized version of Normal distribution. It is not clear to me how to enforce zero-sum there.

@fehiepsi You're right, thats a likelihood of a normal distribution thats scaled to accommodate that zerosum axes are transformed to be shape n-1. I realized this is missing the zero sum transform like @tillahoffmann is mentioning, where the original input is transformed to a vector of size n-1 along each zerosum axis - I'm working on implementing that now. Here's the equation from the Olga and Barber paper (eq. 9, page 12)

$$ \begin{gather} \epsilon_{-J} \sim Normal_{J-1}(0,\Sigma) \\ \Sigma_{j,k} = - \frac{\sigma^2}{J} ,\ j \neq k \ \text{and} \ \Sigma_{j,j} = \sigma^2_{\epsilon} \\ \epsilon_{J} = - \Sigma_{j=1}^{J-1} \epsilon_j \end{gather} $$

Question for both of you, are there any resources that discuss numpyro transforms a little further? I'm having trouble understanding how I would implement this - creating a ZeroSumTransform class seems simple, but how to call that is where I'm struggling. Most of the transforms I'm seeing right now take a base distribution and apply a transform to the base distribution - I cant tell if thats necessarily the right approach here.

An example of what I'm referring to is the lognormal distribution, which is a transform on top of a normal distribution

class LogNormal(TransformedDistribution):
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.positive
    reparametrized_params = ["loc", "scale"]

    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
        base_dist = Normal(loc, scale)
        self.loc, self.scale = base_dist.loc, base_dist.scale
        super(LogNormal, self).__init__(
            base_dist, ExpTransform(), validate_args=validate_args
        )

I think this pattern wouldn't work well with the current sampling method, and potentially with the log prob as well?

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Mar 5, 2024

  • It seems to me that n_zerosum_axes is the number of event dimensions of this distribution. I think it is better to use only two parameters scale and event_dim. You can add the interpretation for them in the description.

@fehiepsi Yes I can add some applications in the description! In the meantime, the typical use cases are to use this as an alternative for dummy-encoding, where there can be a parameter for every single category of interest while still having an intercept parameter without running into identifiability issues

Technically I think n_zerosum_axes can be different from the event_dim. For example, lets say we have a 2d categorical variable of shape=(2, 50), we may only want to enforce a zerosum constraint over the last dim, i.e.

with numpyro.plate("n_conditions", 2):
     with numpyro.plate("n_categories", 20, dim=-1):
          a_category = numpyro.sample("a_category", dist.ZeroSumNormal(1, n_zerosum_axes=1))

curious what you think given that example?

@fehiepsi
Copy link
Member

fehiepsi commented Mar 5, 2024

I think you can use the LKJCholesky as an example. It has a non-trivial support which requires sort of lower cholesky transform. Re event dim: this denotes correlated dimensions, which is the last dimension of your ZeroSum example. Plate denote independent dimension. So you need to remove the category plate if you think that those categories are correlated (via zero sum enforcement).

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Mar 5, 2024

I think you can use the LKJCholesky as an example. It has a non-trivial support which requires sort of lower cholesky transform. Re event dim: this denotes correlated dimensions, which is the last dimension of your ZeroSum example. Plate denote independent dimension. So you need to remove the category plate if you think that those categories are correlated (via zero sum enforcement).

edit: note I'm assuming you mean for event_dim to be equivalent to the previous use of n_zerosum_axes in everything below

To make sure I have this right - lets say we're modeling the effects of the 50 states in the USA and some continuous regressor X. The traditional way to do this simply may be to have a separate intercept parameter for every state and no global intercept

example 1:

def model_example1(X, state, y=None):
    with numpyro.plate("states", 50):
        alpha_state = numpyro.sample("alpha_state", dist.Normal(0, 1))

    beta = numpyro.sample("beta", dist.Normal( 0, 1) )
    sigma = numpyro.sample("sigma", dist.Exponential(1) )

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(alpha_state[state] + beta*X, sigma), obs=y)

If we want to have a global intercept parameter for interpretability, the ZeroSumNormal could come in as follows

example 2a:

def model_example2a(X, state, y=None):
    with numpyro.plate("states", 50):
        alpha_state = numpyro.sample("alpha_state", dist.ZeroSumNormal(scale=1, n_zerosum_axes=1))
   
    global_intercept = numpyro.sample("global_intercept", dist.Normal( 0, 1) )
    beta = numpyro.sample("beta", dist.Normal( 0, 1) )
    sigma = numpyro.sample("sigma", dist.Exponential(1) )

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(global_intercept + alpha_state[state] + beta*X, sigma), obs=y)

I didnt quite think about how the zerosum constraint creates correlated dimensions. Would that mean the actual better pattern to use instead of example 2a) is the following?

example 2b:

def model_example2b(X, state, y=None):
   
    alpha_state = numpyro.sample("alpha_state", dist.ZeroSumNormal(scale=1, n_zerosum_axes=1, support_shape=(50,))
   
    global_intercept = numpyro.sample("global_intercept", dist.Normal( 0, 1) )
    beta = numpyro.sample("beta", dist.Normal( 0, 1) )
    sigma = numpyro.sample("sigma", dist.Exponential(1) )

    with numpyro.plate("data", len(X)):
        numpyro.sample("obs", dist.Normal(global_intercept + alpha_state[state] + beta*X, sigma), obs=y)

Or are you saying that we can keep the structure of example 2a (and just rename n_zerosum_axes to event_dim) and then have internal logic that separates out the event_shape so that the dims (the 50 states in this case) arent independent of eachother

@aseyboldt
Copy link

Nice to see this implemented in numpyro as well :-)

Implementing it using a ParameterFreeTransform could do the trick as described in Ogle and Baker (2020, p. 12). Since each sample of size n is constrained such that it sums to zero, posterior sampling is difficult unless the constrained vector is first transformed to an unconstrained vector of size n - 1 akin to the StickBreakingTransform for the Dirichlet distribution. This does however come at the cost of having to place a full-rank multivariate normal prior on the unconstrained space.

We're just writing up a manuscript for a sum-to-zero transformation/distribution whose log_prob evaluation scales linearly with n. Not quite there yet, but hope to be able to share soon.

I didn't follow the discussion fully, but the current implementation in PyMC already scales linearly in n. We use the fact that the subspace defined by the zero-sum property is a linear subspace, so we can just map the vector space that is one dim smaller using a householder transformation.

@fehiepsi
Copy link
Member

fehiepsi commented Mar 5, 2024

example 2b:

yes, that's my interpretation from the formula. The sum of alpha_state is zero hence they are correlated.

subspace defined by the zero-sum property is a linear subspace

yeah, I think we can just follow the way samples are generated: pad the unconstrained vector by zero across all correlated dimensions, then subtract the mean, etc. Probably householder transform performs a better job. Could you provide some pointer, @aseyboldt ?

@aseyboldt
Copy link

aseyboldt commented Mar 6, 2024

You can find the code of the transformation here: https://github.com/pymc-devs/pymc/blob/main/pymc/distributions/transforms.py#L266

The covariance of the standard ZSN has one zero eigenvalue with eigenvector $w = (1, 1, \dots)^T$, while all other eigenvalues are equal to one. We would like to find an orthogonal map that maps the span of $w$ to the span of $e_1 = (1, 0, 0, \dots)^T$. We can write Householder transformations as $H_v = I - 2vv^T$, where $||v|| = 1$, so we would like to find a $v$ such that $H_v w = \alpha e_1$. For instance $v = \frac{w + ||w|| e_1}{||w + ||w|| e_1||}$ does the job (see for instance https://de.wikipedia.org/wiki/Householdertransformation#Konstruktion_einer_spezifischen_Spiegelung The English wikipedia article doesn't contain that, but I'm sure there are tons of references out there somewhere...). $H_v$ is its own inverse, so we can map between the transformed and untransformed spaces by just applying $H_v$, and we can do that without computing the whole matrix $H_v x = (I - 2vv^T)x = x - 2v(v^Tx)$. And since $H_v$ is orthogonal, the jacobian det is one.

In this transformed space the ZSN is just a diagonal normal with the first variance zero, so we can get rid of the first entry of that array entirely, and get a map to $n-1$ dimensional space.

If your design for distributions allows you to define the density on the transformed space directly, that makes things really easy here, since on the transformed space we have an iid normal distribution. If not (as is unfortunately the case in pymc) you have to trick a bit to still write down a density in the untransformed space that works out (See https://github.com/pymc-devs/pymc/blob/main/pymc/distributions/multivariate.py#L2841).

Hope that helps, if not feel free to ask.

To elaborate on the use case of the ZSN a bit, that would I think be mostly those two:

  • In regressions with catogorical outcomes we can use ZSN dists instead of using one outcome as the reference. In a frequentist setting the choice of reference doesn't matter, but if we have priors it often does. But if no outcome is special in a way where it would make sense to use it as a reference we can use the ZSN to avoid that arbitrary choice.
  • Hierarchical models with categorical predictors are often over parameterized if we don't take some care. Say for instance we have a model like this (using pymc code, by I hope you still get the idea):
with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept = pm.Normal("intercept", sigma=intercept_sd)
    group_effect = pm.Normal("group_effect", sigma=group_sd, dims="group")

    mu = intercept + group_effect[group_idx]

    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)

Here mu has n_groups degrees of freedom, but we have n_groups + 1 degrees of freedom in the model. So if the data size increases (while keeping the same number of groups) there must be a 1-dimensional subspace where the posterior variance stays constant, while all other posterior variances go to zero. But this means that the condition number of the posterior variance goes to infinity. We don't want that...

We can fix this by splitting the group_effect into two parts. For any $x ~ N(0, \sigma^2 I)$ we can always write it as $$x_{\text{deflect}} \sim ZSN(\sigma^2), \quad x_{\text{mean}} \sim N(0, \frac{\sigma^2}{\text{len}(x)}), \quad x = x_{\text{deflect}} + x_{\text{mean}}$$, so

with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept = pm.Normal("intercept", sigma=intercept_sd)
    group_mean = pm.Normal("group_mean", sigma=pt.sqrt(group_sd**2 / n_groups))
    group_deflect = pm.ZeroSumNormal("group_deflect", sigma=group_sd, dims="group")
    group_effect = group_mean + group_deflect

    mu = intercept + group_mean + group_deflect[group_idx]

    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)

But now we can merge intercept + group_mean into one normally distributed random variable, and that way get rid of the extra degree of freedom:

with pm.Model():
    # Fixed for simplicity, nothing really changes if we estimate standard deviations
    intercept_sd = 5
    group_sd = 2

    intercept_plus_group_mean = pm.Normal("intercept_plus_group_mean", sigma=pt.sqrt(intercept_sd ** 2 + group_sd**2 / n_groups))
    group_deflect = pm.ZeroSumNormal("group_deflect", sigma=group_sd, dims="group")
    group_effect = group_mean + group_deflect

    mu = intercept_plus_group_mean + group_deflect[group_idx]

    sigma = 1
    pm.Normal("data", mu=mu, sigma=sigma, observed=data)

This also works for more complicated hierarchical models, in some cases you need several zero-sum axes then.

If necessary, draws from the original intercept and group_mean can then be reconstructed after sampling.
A nice bonus of this is that I noticed in many applications I really want to know group_deflect, and not group_effect, especially if there really is only a fixed number of groups, and it doesn't really make sense to think about the infinite population of groups. So for instance if each group is a US state, do we really want to figure out how California differs from the mean of a (completely imaginary) infinite population of US states, or do we want to know how it differs from the mean of the actual US states?

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Mar 7, 2024

@aseyboldt thank you for the great response thats really helpful.

you have to trick a bit to still write down a density in the untransformed space that works out

I was under the impression that in pymc the ZeroSumTransform is applied and then the logprob is contributed, is that not actually the case? Where exactly do the forward and backward transforms get applied? (I'm admittedly having trouble following the _default_transform logic)

@@ -1380,6 +1382,92 @@ def __eq__(self, other):
return jnp.array_equal(self.transition_matrix, other.transition_matrix)


class ZeroSumTransform(Transform):
"""A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]
Copy link
Contributor Author

@kylejcaron kylejcaron Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexAndorra @aseyboldt @ricardoV94 same as I said above, this PR is nearing ready to go - let me know if there's more I can add to properly credit all of you and pymc

for axis in zero_sum_axes:
theoretical_var *= 1 - 1 / self.event_shape[axis]

return theoretical_var
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the mean, we need to broadcast this to batch_shape + event_shape

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, updated!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending feedback from @tillahoffmann! Thanks for the awesome work, @kylejcaron!

numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
Copy link
Contributor

@tillahoffmann tillahoffmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! A few more little comments.

numpyro/distributions/constraints.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/transforms.py Outdated Show resolved Hide resolved
Copy link
Contributor

@tillahoffmann tillahoffmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few more little ones, sorry this is such a piecemeal review.

numpyro/distributions/constraints.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Show resolved Hide resolved
numpyro/distributions/transforms.py Outdated Show resolved Hide resolved
@kylejcaron
Copy link
Contributor Author

Few more little ones, sorry this is such a piecemeal review.

dont worry about it I really appreciate all of the help, I still have a lot to learn - thank you! unless you can find any other issues, that discussion on the scale param might be the last unresolved piece

Copy link
Contributor

@tillahoffmann tillahoffmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great! Thank you, @kylejcaron!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also thanks for a detailed review, @tillahoffmann!

@kylejcaron
Copy link
Contributor Author

Also thanks for a detailed review, @tillahoffmann!

Thank you everyone for all of the help its been much appreciated! @fehiepsi I had to re-request review, the doctest for distributions/continuous.py was failing and didnt show up when I ran make doctest (maybe because the nested_sampling module was failing first?).

I was able to test it manually with python -m doctest -v numpyro/distributions/continuous.py and can confirm its fixed

@kylejcaron kylejcaron changed the title [WIP] ZeroSumNormal distribution Add ZeroSumNormal distribution Mar 29, 2024
@fehiepsi fehiepsi merged commit 68eb218 into pyro-ppl:master Mar 30, 2024
4 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* added zerosumnormal and tests

* added edge case handling for support shape

* removed commented out functions

* added zerosumnormal to docs

* fixed zerosumnormal support shape default

* Added v1 of docstrings for zerosumnormal

* updated zsn docstring

* improved init shape handling for zerosumnormal

* improved docstrings

* added ZeroSumTransform

* made n_zerosum_axes an attribute for the zerosumtransform

* removed commented out lines

* added zerosumtransform class

* switched zsn from ParameterFreeTransform to Transform

* changed ZeroSumNormal to transformed distibutrion

* changed input to tuple for _transform_to_zero_sum

* added forward and inverse shape to transform, fixed zero_sum constraint handling

* fixed failing zsn tests

* added docstring, removed whitespace, fixed missing import

* fixed allclose to be assert allclose

* linted and formatted

* added sample code to docstring for zsn

* updated docstring

* removed list from ZeroSum constraint call

* removed unneeded iteration, updated docstring

* updated constraint code

* added ZeroSumTransform to docs

* fixed transform shapes

* added doctest example for zsn

* added constraint test

* added zero_sum constraint to docs

* added type hinting to transforms file

* fixed docs formatting

* moved skip zsn from test_gof earlier

* reversed zerosumtransform

* broadcasted mean and var of zsn

* added stricter zero_sum constraint tol, improved mean and var functions

* fixed _transform_to_zero_sum

* removed shape promote from zsn, changed broadcast to zeros_like

* chose better zsn test cases

* Update zero_sum constraint feasible_like

Co-authored-by: Till Hoffmann <[email protected]>

* fixed docstring for doctests

---------

Co-authored-by: Till Hoffmann <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants