Porting the GMM tutorial #1562
Conversation
Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.
Hmm ... failing tests. (UPDATE: tests passed after rerunning them.)
@ordabayevy You can use the following hook to collect gradient norms:

```python
from collections import defaultdict

import jax
import jax.numpy as jnp
import optax


def hook_optax(optimizer):
    """Wrap an optax optimizer so per-parameter gradient norms are recorded at each update."""
    gradient_norms = defaultdict(list)

    def append_grad(grad):
        # Record the norm of each parameter's gradient, then pass the gradients through unchanged.
        for name, g in grad.items():
            gradient_norms[name].append(float(jnp.linalg.norm(g)))
        return grad

    def update_fn(grads, state, params=None):
        # pure_callback lets the (side-effecting) recorder run inside jit-compiled code.
        grads = jax.pure_callback(append_grad, grads, grads)
        return optimizer.update(grads, state, params=params)

    return optax.GradientTransformation(optimizer.init, update_fn), gradient_norms


optimizer, gradient_norms = hook_optax(optax.adam(0.1))
params = {"x": 0.0}
opt_state = optimizer.init(params)


def loss(params):
    return (params["x"] - 2.0) ** 2


@jax.jit
def step(params, opt_state):
    loss_value, grads = jax.value_and_grad(loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value


for i in range(100):
    params, opt_state, loss_value = step(params, opt_state)

params, gradient_norms
```
Thanks @fehiepsi! Updated the tutorial to include gradient norms. Here are the plots I get. The first plot deviates a bit from the Pyro tutorial one, but maybe it is okay?
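(For reference, a minimal sketch of how such gradient-norm plots can be drawn from the `gradient_norms` dict collected by the hook above; the figure size, log scale, and labels are illustrative, not taken from the tutorial.)

```python
import matplotlib.pyplot as plt

# Plot one curve per parameter from the norms recorded by hook_optax.
plt.figure(figsize=(10, 4), dpi=100)
for name, norms in gradient_norms.items():
    plt.plot(norms, label=name)
plt.xlabel("iters")
plt.ylabel("gradient norm")
plt.yscale("log")
plt.legend(loc="best")
plt.title("Gradient norms during SVI")
plt.show()
```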
Awesome work! Thanks, Yerdos!
notebooks/source/gmm.ipynb (Outdated)

```json
{
 "data": {
  "text/plain": [
   "Text(0.5, 1.0, 'Convergence of SVI')"
```
You can add `plt.show()` at the end of each cell block to remove those lines.
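For example, a plotting cell whose last expression is a title call prints the returned matplotlib Text object, which is exactly the `Text(0.5, 1.0, 'Convergence of SVI')` line quoted above; ending the cell with `plt.show()` suppresses it. The `losses` variable below is assumed to be whatever the SVI loop in the notebook collects.

```python
import matplotlib.pyplot as plt

plt.plot(losses)                 # assumed: per-step ELBO losses from the SVI run
plt.title("Convergence of SVI")  # returns a Text object, printed if it is the last expression
plt.show()                       # calling show() last keeps the Text repr out of the cell output
```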
notebooks/source/gmm.ipynb (Outdated)

```json
    "\n",
    "\n",
    "optim, gradient_norms = hook_optax(optax.adam(learning_rate=0.1, b1=0.8, b2=0.99))\n",
    "global_svi = SVI(model, global_guide, optax_to_numpyro(optim), loss=elbo)"
```
Why do we need `optax_to_numpyro`? SVI should handle it under the hood.
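If SVI does accept an optax optimizer directly, as suggested, the quoted cell could presumably be simplified to something like the sketch below (reusing the names from the diff above):

```python
optim, gradient_norms = hook_optax(optax.adam(learning_rate=0.1, b1=0.8, b2=0.99))

# Pass the optax GradientTransformation straight to SVI; the optax_to_numpyro
# wrapping is reportedly applied under the hood.
global_svi = SVI(model, global_guide, optim, loss=elbo)
```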
Didn't know that!
```python
    value, name_to_dim=node["infer"]["name_to_dim"]
)
else:
    if node["infer"].get("enumerate") == "parallel":
```
The logic is different from the previous one. Could you add a comment explaining this change?
There are non-enumerated latent variables in the GMM tutorial that trigger the `else` branch here, but since these variables are not enumerated, `_get_support_value(log_measure, name)` raises an error.

I think we only need to get support values for enumerated variables; for other variables, adjoint should just return their sampled/observed values. I don't know what has changed in funsor since this code was written, but it seems like adjoint handles those cases well under the hood (see the `# TODO this should really be handled entirely under the hood by adjoint` comment).
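Roughly, the intended behaviour is the following (a simplified sketch of the described change, not the actual numpyro/funsor source):

```python
# Sketch: only enumerated sites need their support value extracted from the
# adjoint's log_measure; non-enumerated sites keep their sampled/observed value.
if node["infer"].get("enumerate") == "parallel":
    value = _get_support_value(log_measure, name)
else:
    value = node["value"]
```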
Porting the GMM tutorial from Pyro. I was able to reproduce everything except tracking gradient norms with `.register_hook`. I had to adjust `infer_discrete` so it would work with models that have non-enumerated variables (global variables in this case).
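For reference, the Pyro tutorial tracks gradient norms with PyTorch tensor hooks, roughly as below (quoted from memory as a sketch, not verbatim from the tutorial):

```python
from collections import defaultdict

import pyro

# Register a hook on each parameter (after an initial SVI step has populated
# the param store) so its gradient norm is recorded on every backward pass.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(
        lambda g, name=name: gradient_norms[name].append(g.norm().item())
    )
```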