Added `render_params` argument in `render_model` #1381
Conversation
numpyro/contrib/render.py (Outdated)

```python
params = {
    name: ProvenanceArray(site["value"], frozenset({name}))
```
In NumPyro, a `param`'s value can be a (nested) dictionary (e.g. those of a network's parameters). I think we need some assertion, either here or in `ProvenanceArray`, that the data has `shape` and `dtype` attributes. If you want to address that issue, I think we can use

```python
name: jax.tree_util.tree_map(lambda x: ProvenanceArray(x, frozenset({name})), site["value"])
```

here. But we need to add a test or an example in the notebook to make sure that it works. E.g. a model similar to the one in your notebook would be
```python
import flax.linen as nn
from numpyro.contrib.module import flax_module

def model(data):
    lambda_base = numpyro.sample("lambda", dist.Normal(0, 1))
    net = flax_module("affine_net", nn.Dense(1), input_shape=(1,))
    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)
```
In the plot, I would expect the `lambda` sample node and the `affine_net$params` param node to point to the `obs` observed node.
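The `tree_map` suggestion above can be sketched standalone. `Tagged` below is a hypothetical stand-in for `ProvenanceArray`, and the nested dict mimics what a flax param site's value looks like; only `jax.tree_util.tree_map` itself is from the suggestion.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for numpyro's ProvenanceArray: records a leaf's
# value together with the set of site names it came from.
class Tagged:
    def __init__(self, data, provenance):
        self.data = data
        self.provenance = provenance

# A param site whose value is a nested dict, like a network's parameters.
site_value = {"Dense_0": {"kernel": jnp.zeros((1, 1)), "bias": jnp.zeros((1,))}}
name = "affine_net$params"

# tree_map walks every leaf of the nested dict and wraps it, leaving the
# dict structure itself untouched.
tagged = jax.tree_util.tree_map(lambda x: Tagged(x, frozenset({name})), site_value)

print(tagged["Dense_0"]["bias"].provenance)  # frozenset({'affine_net$params'})
```

This is why the single-`ProvenanceArray` version breaks on dict-valued sites: wrapping the whole dict gives an object with no `shape`/`dtype`, while mapping over leaves tags each array individually.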
Interesting! Thanks @fehiepsi for pointing out this issue and giving the solution. I have added that line. Please check the following output:
```python
import flax.linen as flax_nn
from numpyro.contrib.module import flax_module

def model(data):
    lambda_base = numpyro.sample("lambda", dist.Normal(0, 1))
    net = flax_module("affine_net", flax_nn.Dense(1), input_shape=(1,))
    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Exponential(lambd), obs=data)

numpyro.render_model(model, model_args=(data,), render_distributions=True, render_params=True)
```
But I'm not sure where this example belongs, so I have added it to `model_rendering.ipynb` under the title "Rendering neural network's parameters". Feel free to suggest any edits.
Whoa, it is actually working. Thanks, @karm-patel! Do you think it is better to strip out the `$params` part in `render_model` (at the `cur_graph.node(rv, label=rv, shape=shape, style="filled", fillcolor=color)` line)?
Yes, I removed it. @fehiepsi, please check.
" x = numpyro.sample(\"x\", dist.Normal(0, 1))\n", | ||
" y = numpyro.sample(\"y\", dist.LogNormal(x, 1))\n", | ||
" m = numpyro.param(\"m\", jnp.array(0))\n", | ||
" sd = numpyro.param(\"sd\", jnp.array([1]), constraint=constraints.positive)\n", |
I think you can set the init values for `m` and `sd` to `0.` and `1.`. You might want to change those lines to

```python
self.shape = jnp.shape(data)
self.dtype = jnp.dtype(data)
```

for it to work.
Okay, done.
Edit: if I keep `jnp.dtype(data)`, then I think `data` needs to be a `jnp.array()`, because in `test_provenance.py` a NumPy array is passed and that test case was failing. Hence I reverted to `data.dtype`.
Oops, sorry, I think you can use `jnp.result_type(data)` here.
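For context on why `jnp.result_type` is the robust choice (a sketch, not from the PR): plain Python scalars like the suggested init values `0.` and `1.` have no `.dtype` attribute, and NumPy's `dtype` constructor rejects array arguments, while `jnp.shape`/`jnp.result_type` accept all of these inputs uniformly.

```python
import numpy as np
import jax.numpy as jnp

# A plain Python float (e.g. the init value 0.) has no .dtype attribute...
try:
    (0.0).dtype
except AttributeError:
    print("python float has no .dtype")

# ...and numpy's dtype constructor does not accept an array argument.
try:
    np.dtype(np.array([1.0, 2.0]))
except TypeError:
    print("np.dtype(array) raises TypeError")

# jnp.shape and jnp.result_type handle scalars, NumPy arrays, and JAX
# arrays uniformly, which is what the wrapper class needs here.
for data in (0.0, np.array([1.0, 2.0]), jnp.ones(3)):
    print(jnp.shape(data), jnp.result_type(data))
```

Note that `jnp.result_type` canonicalizes dtypes under JAX's default config (e.g. `float64` inputs report `float32` unless x64 is enabled), which is usually what the rest of a JAX program expects anyway.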
Cool! I changed it.
Could you also update the model in the notebook (`jnp.array(0.)` -> `0.`) and rerun it with the new `$params` suffix removed?
Sure, I updated it and re-ran the notebook.
Looks great to me! This is super useful, especially when we use neural networks. We can see clearly which networks are used for which random variables. There are two nits, as commented above.
Hi @fehiepsi and team, regarding issue #1379, I've made the following changes:

- Added a `render_params` argument in `render_model`
- Added `sample_param` and `param_constraint` in the dictionary returned by `get_model_relations()`
- Modified `get_model_relations` to add values in `sample_param`
- Added an example of `render_params` in the tutorial `model_rendering.ipynb`

@fehiepsi, I would appreciate your code review!