-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add inference utilities to transform between unconstrained and constrained space #1564
Conversation
numpyro/infer/util.py
Outdated
names. | ||
:return: `dict` of transformation keyed by site names. | ||
""" | ||
transforms = get_transforms(model, model_args, model_kwargs, params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here params
is in unconstrained space, so we can't substitute natively. If you want to deal with params, please feel free to adjust constrain_fn
for it. Maybe also rename unconstrain_values
to unconstrain_fn
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed you're right. Let me know if your points have been addressed by the latest commit :)
numpyro/infer/util.py
Outdated
@@ -157,7 +157,8 @@ def transform_fn(transforms, params, invert=False): | |||
return {k: transforms[k](v) if k in transforms else v for k, v in params.items()} | |||
|
|||
|
|||
def constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False): | |||
def constrain_fn(model, model_args, model_kwargs, params, | |||
include_param_sites=False, return_deterministic=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can remove those include_param_sites
flags and keep the True behavior at all functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was for ensuring backward compatibility. I removed these in the last comit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. I guess it won't affect the old behavior. If users provide param sites, your solution will return expected results, while the current master branch will raise an error or skip them - so this is an improvement.
substituted_model = substitute(model, data=params) | ||
transforms, _, _, _ = _get_model_transforms(substituted_model, model_args, model_kwargs) | ||
return transforms | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you run make format
to fix lint issue? I guess you need to add a new line here
…ained space Improve and simplify constrain_fn and unconstrain_fn implementation Add missing doctstrings Constrain/unconstrain functions now always consider param sites Fix syntax for lint tests Fix syntax for lint tests Fix syntax for lint tests
@fehiepsi fyi I squashed the commits of this PR as I had to do lots of syntax fixing recently... should be ok now |
Thanks @aymgal! It's great to have this utility available. |
This is a proposal to fix #1554