-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
sample from distribution without storing #1790
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice solution, thanks @amifalk!
numpyro/infer/mcmc.py
Outdated
if field_name.startswith(f"~{self._sample_field}."): | ||
remove_sites.append(field_name[len(self._sample_field) + 2 :]) | ||
else: | ||
collect_fields.append(field_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, it seems that making collect_fields
a set is slightly better (to avoid collecting duplicating fields). Maybe setting collect_fields=tuple(set(collect_fields))
below? Or using dict like what you did before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set is not insertion order preserving and the current solution relies on self._sample_field being the first item in collect_fields so I'll revert to the dictionary approach.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful work! Thanks, @amifalk!
* exclude sample sites with "~" * handle repeat remove_sites * test exclude sites * fix test case and len(1) collect_fields edge case * add dict check, switch to list, add documentation * back to dict solution
As discussed in #1695. Users can exclude a sample site from collection by adding
~{sample_field}.{site_name}
toextra_fields
inMCMC.run
orMCMC.warmup
.Because model initialization doesn't happen until
_single_chain_mcmc
is called and the logic to set up collect_fields happens before that, I opted not to make default_fields mutable/settable withinfer
.