
Added render_params argument in render_model #1381

Merged
merged 5 commits into pyro-ppl:master on Apr 4, 2022

Conversation

@karm-patel (Contributor):

Hi @fehiepsi and team, regarding issue #1379, I've made the following changes:

  1. Added a render_params argument to render_model
  2. Added the keys sample_param and param_constraint to the dictionary returned by get_model_relations()
  3. Enabled provenance tracking for params in get_model_relations to populate sample_param
  4. Added an example on using render_params to the tutorial model_rendering.ipynb

@fehiepsi, I would appreciate your code review!
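For reference, a minimal pure-Python sketch of what the extended relations dictionary from item 2 could look like. The key names follow this PR's description; the exact structure and the site names ("obs", "lambda", "affine_net", "sd") are illustrative assumptions, not verified against the final API:

```python
# Hypothetical shape of the dict returned by get_model_relations()
# after this PR; only the sample_param / param_constraint keys are
# named in the PR description, the rest is illustrative.
relations = {
    "sample_sample": {"obs": ["lambda"]},      # sample site -> upstream sample sites
    "sample_param": {"obs": ["affine_net"]},   # NEW: sample site -> upstream param sites
    "param_constraint": {"sd": "positive"},    # NEW: param site -> its constraint (or "")
    "sample_dist": {"lambda": "Normal", "obs": "Exponential"},
}

# A renderer could then draw param -> sample edges like this:
param_edges = [
    (param, sample)
    for sample, params in relations["sample_param"].items()
    for param in params
]
print(param_edges)  # [('affine_net', 'obs')]
```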

@review-notebook-app (bot):

Check out this pull request on ReviewNB to see visual diffs and provide feedback on the Jupyter notebooks.

Reviewed hunk:

    }

    params = {
        name: ProvenanceArray(site["value"], frozenset({name}))
@fehiepsi (Member) commented on Mar 29, 2022:

In NumPyro, a param's value can be a (nested) dictionary (e.g. a network's parameters). I think we need an assertion, either here or in ProvenanceArray, that the data has shape and dtype attributes. If you want to address that issue, I think we can use

name: jax.tree_util.tree_map(lambda x: ProvenanceArray(x, frozenset({name})), site["value"])

here. But we need to add a test, or an example in the notebook, to make sure that it works. E.g., a model similar to the one in your notebook would be

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import flax.linen as nn
from numpyro.contrib.module import flax_module

def model(data):
    lambda_base = numpyro.sample("lambda", dist.Normal(0, 1))
    net = flax_module("affine_net", nn.Dense(1), input_shape=(1,))
    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)

In the plot, I would expect the lambda sample node and the affine_net$params param node to point to the obs observed node.
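To illustrate why jax.tree_util.tree_map is needed here: a flax module's params are a nested dict of arrays, so the provenance wrapper must be applied leaf by leaf rather than to the dict itself. A minimal pure-Python sketch, where both the ProvenanceArray class and tree_map are stand-ins for the real numpyro/jax versions, and the param values are illustrative:

```python
# Stub standing in for numpyro's ProvenanceArray: the real class wraps
# array-like data and records which sites it came from.
class ProvenanceArray:
    def __init__(self, data, provenance):
        self.data = data
        self.provenance = provenance

def tree_map(f, tree):
    # Hand-rolled stand-in for jax.tree_util.tree_map; handles dicts only.
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    return f(tree)

# A flax module's param site value is a nested dict of arrays.
site_value = {"affine_net$params": {"kernel": [[1.0]], "bias": [0.0]}}
name = "affine_net"
wrapped = tree_map(lambda x: ProvenanceArray(x, frozenset({name})), site_value)

leaf = wrapped["affine_net$params"]["kernel"]
print(leaf.provenance)  # frozenset({'affine_net'})
```

Wrapping the whole dict in one ProvenanceArray would fail later, since a dict has no shape or dtype; wrapping each leaf keeps the tree structure intact while tagging every array with its site name.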

@karm-patel (Contributor, author) commented on Apr 2, 2022:

Interesting! Thanks @fehiepsi for pointing out this issue and providing the solution. I have added that line; please check the following output:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import flax.linen as flax_nn
from numpyro.contrib.module import flax_module

def model(data):
    lambda_base = numpyro.sample("lambda", dist.Normal(0, 1))
    net = flax_module("affine_net", flax_nn.Dense(1), input_shape=(1,))
    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)

numpyro.render_model(model, model_args=(data,), render_distributions=True, render_params=True)

[rendered graph: the lambda sample node and the affine_net$params param node point to the obs node inside plate N]

But I'm not sure where to add this example, so I have added it to model_rendering.ipynb under the title "Rendering neural network's parameters". Feel free to suggest any edits.

@fehiepsi (Member) commented:

Whoa, it is actually working. Thanks, @karm-patel! Do you think it is better to strip out the $params part in render_model (at the cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color) line)?

@karm-patel (Contributor, author) commented:

Yes, I removed it. @fehiepsi, please check it.
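For reference, one possible way to strip the suffix before it is used as the graphviz label — a hypothetical sketch, not the actual change in render_model, and the helper name display_name is made up:

```python
# Hypothetical helper: drop a trailing "$params" from a node name
# before passing it as the label to cur_graph.node(...).
def display_name(rv: str) -> str:
    suffix = "$params"
    return rv[: -len(suffix)] if rv.endswith(suffix) else rv

print(display_name("affine_net$params"))  # affine_net
print(display_name("lambda"))             # lambda
```

Checking endswith first keeps ordinary sample-site names untouched while shortening only param nodes such as affine_net$params.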

notebooks/source/model_rendering.ipynb (outdated, resolved)
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.LogNormal(x, 1))
    m = numpyro.param("m", jnp.array(0))
    sd = numpyro.param("sd", jnp.array([1]), constraint=constraints.positive)
@fehiepsi (Member) commented:

I think you can set the init values for m and sd to 0. and 1. respectively. You might want to change those lines to

        self.shape = jnp.shape(data)
        self.dtype = jnp.dtype(data)

for it to work.

@karm-patel (Contributor, author) commented on Apr 2, 2022:

Okay, done.
Edit: if I keep jnp.dtype(data), then data needs to be a jnp array, because test_provenance.py passes a NumPy array and that test case was failing. Hence I reverted to data.dtype.

@fehiepsi (Member) commented:

Oops, sorry, I think you can use jnp.result_type(data) here.
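A quick check of why result_type is the robust choice, using NumPy's np.result_type (which jnp.result_type mirrors): np.dtype cannot interpret a plain ndarray as a data type, and Python scalars have no .dtype attribute, while result_type handles arrays, scalars, and anything array-like:

```python
import numpy as np

data_array = np.array([1.0, 2.0])
data_scalar = 0.5

# np.dtype rejects an ndarray: it expects something dtype-like.
try:
    np.dtype(data_array)
    dtype_raised = False
except TypeError:
    dtype_raised = True
print(dtype_raised)  # True

# A plain Python float has no .dtype attribute, so data.dtype
# would fail on it as well.
print(hasattr(data_scalar, "dtype"))  # False

# result_type works uniformly for both cases.
print(np.result_type(data_array))   # float64
print(np.result_type(data_scalar))  # float64
```

This matches the failure described above: data.dtype assumes an array, jnp.dtype assumes a dtype-like, and result_type covers both the NumPy array from test_provenance.py and scalar inputs.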

@karm-patel (Contributor, author) commented:

Cool! I changed it.

@fehiepsi (Member) commented:

Could you also update the model in the notebook (jnp.array(0.) -> 0.) and rerun it so the rendering reflects the removed $params suffix?

@karm-patel (Contributor, author) commented:

Sure, I updated it and re-ran the notebook.

fehiepsi previously approved these changes on Apr 2, 2022

@fehiepsi (Member) left a comment:

Looks great to me! This is super useful, especially when we use neural networks. We can clearly see which networks are used for which random variables. There are two nits, as commented above.

@karm-patel karm-patel requested a review from fehiepsi April 3, 2022 18:01
@fehiepsi fehiepsi merged commit 1d4bba1 into pyro-ppl:master Apr 4, 2022