Clarify how to sample the centered parameter in LocScaleReparam #1598

Closed
Madhav-Kanda opened this issue Jun 1, 2023 · 3 comments

Comments

@Madhav-Kanda
Contributor

The documentation for LocScaleReparam says of the centered argument: "If None (default) learn a per-site per-element centering parameter in [0,1]". Looking through the implementation, however, I could not find anything that learns the correct centeredness. As far as I understand the implementation, the default sets the centeredness to 0.5 and never optimizes it.
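For reference, the role of the centeredness can be written out as a short sketch. It assumes a Normal site with location $\mu$ and scale $\tau$, and uses $c$ for the centered value and $\tilde{\theta}$ for the auxiliary (decentered) variable; these symbols are introduced here for illustration and are not taken from the documentation:

$$
\tilde{\theta} \sim \mathcal{N}\left(c\,\mu,\ \tau^{c}\right), \qquad
\theta = \mu + \tau^{\,1-c}\left(\tilde{\theta} - c\,\mu\right)
$$

With $c = 1$ this reduces to $\theta = \tilde{\theta} \sim \mathcal{N}(\mu, \tau)$ (fully centered), and with $c = 0$ it reduces to $\theta = \mu + \tau\,\tilde{\theta}$ with $\tilde{\theta} \sim \mathcal{N}(0, 1)$ (fully decentered).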

Example:
In the eight schools example, the ideal centeredness should be close to 0. However, using LocScaleReparam with centered=None gives a centeredness of 0.5, as seen below (which I believe is hard-coded in the implementation of LocScaleReparam for the None case).

import numpy as np
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def eight_schools_noncentered(J, sigma, lambd, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        # lambd is passed through as the centeredness of theta
        with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=lambd)}):
            theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

nuts_kernel = NUTS(eight_schools_noncentered)
rng_key = random.PRNGKey(0)

# Run 1: centered=None (the documented "learn" default)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, J, sigma, lambd=None, y=y, extra_fields=('potential_energy',))
mcmc.print_summary()

# Run 2: centered fixed at 0.5
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, J, sigma, lambd=0.5, y=y, extra_fields=('potential_energy',))
mcmc.print_summary()

In the above case, both runs produce the same print_summary output. I therefore believe that either LocScaleReparam should be updated to use SVI to find the best centeredness, or the documentation should be updated accordingly.

@fehiepsi
Member

fehiepsi commented Jun 1, 2023

> does not optimize it

You are running MCMC rather than optimizing. As mentioned in the forum thread, you can use the lift handler to cast the param statement to a sample statement.

@fehiepsi fehiepsi changed the title from "Bug in LocScaleReparam | Correction in Documentation" to "Clarify how to sample the centered parameter in LocScaleReparam" Jun 1, 2023
@fehiepsi
Member

fehiepsi commented Jun 1, 2023

Please feel free to submit a PR to clarify the doc. :)

@Madhav-Kanda
Contributor Author

Sure, I will do that.

Madhav-Kanda added a commit to Madhav-Kanda/numpyro that referenced this issue Jun 1, 2023
fehiepsi pushed a commit that referenced this issue Jun 1, 2023
* Update reparam.py

Refer #1598

* Making suggested changes
@fehiepsi fehiepsi closed this as completed Jun 2, 2023
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this issue Jun 2, 2023
* Update reparam.py

Refer pyro-ppl#1598

* Making suggested changes