-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Raise an error if there is no common scale when model enumerated #1536
Conversation
numpyro/infer/elbo.py
Outdated
else None | ||
scales_set = set( | ||
[ | ||
model_trace[name]["scale"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure how to check the uniqueness of jax array scales because jax arrays is not hashable. Could you replace this by something like
scale = None
for name in (group_names | group_sum_vars):
site_scale = model_trace[name]["scale"]
if isinstance(scale, (int, float)) and isinstance(site_scale, (int, float, type(None))) and (site_scale != scale):
raise ValueError(...)
scale = site_scale
Btw, does this mean that we don't support enumeration for models with both global and local variables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the implementation from Pyro which first checks that array is scalar, converts it to scalar, and then compares them. Does it look alright to you?
https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/traceenum_elbo.py#L37-L41
Btw, does this mean that we don't support enumeration for models with both global and local variables?
Yeah, if you enumerate a global variable than you cannot subsample a local variable that depends on it.
I'm working on it but that will require first changing funsor.sum_product.sum_product
to handle plate-wise scaling there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we can't convert a tracer to float, float(site_scale)
will fail when scale is a tracer. Maybe we need to raise error if it is a jnp.array?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Changed it so that it raises an error if scale is jnp.array
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks, Yerdos.
Thanks for reviewing @fehiepsi . |
It turns out that current subsample scaling logic is not handled correctly in
TraceEnum_ELBO
(sorry for that). I added some tests to demonstrate that. I think the best place to handle subsampling scaling is infunsor.sum_product.sum_product
. For now I added a check for a common scale which raises an error if there is no common scale.