Porting the GMM tutorial #1562

Merged: 6 commits merged into master from gmm on Mar 27, 2023

Conversation

@ordabayevy (Member)

Porting the GMM tutorial from Pyro. I was able to reproduce everything except tracking gradient norms with .register_hook. I had to adjust infer_discrete so it would work with models that have non-enumerated variables (global variables in this case).
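For context, this is roughly the pattern the change enables once training is done. It is a minimal sketch: the names model, global_guide, data, and the assignment site follow the Pyro tutorial, svi_result stands in for the result of SVI.run, and the exact wiring in the ported notebook may differ.

from jax import random
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete

# Freeze the trained global (non-enumerated) latents by replaying the guide,
# then let infer_discrete handle only the enumerated assignment site.
trained_global_guide = handlers.substitute(global_guide, svi_result.params)
guide_trace = handlers.trace(trained_global_guide).get_trace(data)
trained_model = handlers.replay(model, trace=guide_trace)

inferred_model = infer_discrete(
    config_enumerate(trained_model), temperature=0, rng_key=random.PRNGKey(0)
)
trace = handlers.trace(inferred_model).get_trace(data)
map_assignments = trace["assignment"]["value"]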

@ordabayevy (Member, Author) commented Mar 24, 2023

Hmm ... failing tests (test/infer/test_mcmc.py::test_binomial_stable_x64 and test/infer/test_mcmc.py::test_beta_bernoulli_x64[SA]) pass locally on my computer.

UPDATE: tests passed after rerunning them.

@fehiepsi (Member) commented Mar 25, 2023

@ordabayevy You can use the following hook to collect gradient norms:

from collections import defaultdict
import optax
import jax
import jax.numpy as jnp

def hook_optax(optimizer):
    # Collect per-parameter gradient norms as a side effect of each update.
    gradient_norms = defaultdict(list)

    def append_grad(grad):
        for name, g in grad.items():
            gradient_norms[name].append(float(jnp.linalg.norm(g)))
        return grad

    def update_fn(grads, state, params=None):
        # pure_callback lets the Python-side hook run even under jit.
        grads = jax.pure_callback(append_grad, grads, grads)
        return optimizer.update(grads, state, params=params)

    return optax.GradientTransformation(optimizer.init, update_fn), gradient_norms

optimizer, gradient_norms = hook_optax(optax.adam(0.1))
params = {"x": 0.}
opt_state = optimizer.init(params)

def loss(params):
    return (params["x"] - 2.) ** 2

@jax.jit
def step(params, opt_state):
    loss_value, grads = jax.value_and_grad(loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

for i in range(100):
    params, opt_state, loss_value = step(params, opt_state)
params, gradient_norms
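A minimal follow-up sketch to visualize what the hook collects (assuming the loop above has run):

import matplotlib.pyplot as plt

for name, norms in gradient_norms.items():
    plt.plot(norms, label=name)
plt.xlabel("iteration")
plt.ylabel("gradient norm")
plt.yscale("log")
plt.legend(loc="best")
plt.title("Gradient norms during SVI")
plt.show()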

@ordabayevy (Member, Author)

Thanks @fehiepsi! Updated the tutorial to include gradient norms. Here are the plots I get:
[two plots of the collected gradient norms over SVI iterations: norms1, norms2]

The first plot deviates a bit from the Pyro tutorial one, but maybe it is okay?

@fehiepsi (Member) previously approved these changes on Mar 26, 2023 and left a comment:

Awesome work! Thanks, Yerdos!

@fehiepsi (Member) commented on a notebook cell whose output shows:

Text(0.5, 1.0, 'Convergence of SVI')

You can add plt.show() at the end of each cell block to remove those lines.
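For example (the plotting calls stand in for whatever the cell already does):

import matplotlib.pyplot as plt

plt.plot(losses)  # `losses` stands in for the values the cell already plots
plt.title("Convergence of SVI")
plt.show()  # returns None instead of a Text object, keeping the repr out of the output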

"\n",
"\n",
"optim, gradient_norms = hook_optax(optax.adam(learning_rate=0.1, b1=0.8, b2=0.99))\n",
"global_svi = SVI(model, global_guide, optax_to_numpyro(optim), loss=elbo)"
Copy link
Member

Why do we need optax_to_numpyro? SVI should handle it under the hood.

@ordabayevy (Member, Author) replied:

Didn't know that!
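A minimal sketch of the suggestion, reusing the hook_optax helper and the tutorial's model, global_guide, and elbo:

optim, gradient_norms = hook_optax(optax.adam(learning_rate=0.1, b1=0.8, b2=0.99))
# SVI accepts an optax GradientTransformation directly, so no optax_to_numpyro wrapper is needed.
global_svi = SVI(model, global_guide, optim, loss=elbo)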

@fehiepsi (Member) commented on the changed lines in infer_discrete:

value, name_to_dim=node["infer"]["name_to_dim"]
)
else:
if node["infer"].get("enumerate") == "parallel":

The logic is different from the previous one. Could you comment on this change?

@ordabayevy (Member, Author) replied:

There are non-enumerated latent variables in the GMM tutorial that trigger the else branch here, but since these variables are not enumerated, _get_support_value(log_measure, name) raises an error.

I think we only need to get support values for enumerated variables. For the other variables, adjoint should just return their sampled/observed values. I don't know what has changed in funsor since this code was written, but it seems like adjoint now handles those cases under the hood (see the "# TODO this should really be handled entirely under the hood by adjoint" comment).
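Roughly, the intended behavior, as a sketch rather than the exact diff:

# Sketch (not the exact diff): only enumerated sites read their value from the
# funsor adjoint; non-enumerated sites keep their sampled/observed value.
if node["infer"].get("enumerate") == "parallel":
    node["value"] = _get_support_value(log_measure, name)
# else: leave node["value"] as is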

@fehiepsi merged commit a7267d9 into master on Mar 27, 2023
@fehiepsi deleted the gmm branch on Mar 27, 2023 at 06:53