-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pseudo-Random Number Generation For None Input For JAX PRNGKey In Metaclass #192
Pseudo-Random Number Generation For None Input For JAX PRNGKey In Metaclass #192
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #192 +/- ##
==========================================
+ Coverage 94.73% 94.76% +0.02%
==========================================
Files 40 40
Lines 893 898 +5
==========================================
+ Hits 846 851 +5
Misses 47 47
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
In model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions,
) |
I grandfathered in the type def run(
self,
num_warmup,
num_samples,
rng_key: jax.random.PRNGKey | None = None,
nuts_args: dict = None,
mcmc_args: dict = None,
**kwargs,
) -> None: However, the tests fail here, because of
|
Might be a good time to switch to the new style of random keys https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#jep-9263 |
Some links:
For the following,
|
NOTE: At some point, having |
Can we add a test that checks that two models with no rng key do not produce the same results? |
Clever and nice. Co-authored-by: Damon Bayer <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid so much repeated code, I suggest writing some helper functions.
One to create the model, so you can have
model_1 = create_model()
model_2 = create_model()
and one to sample from the model so you can do
sample_model(model_1, rng_key = jr.key(54))
sample_model(model_2, rng_key = jr.key(54))
In
metaclass.py
, the MCMC run require a pseudo-random key (jax.random.PRNGKey()
). Without this PR, if the user does not specify their own PRNGKey, MSR instantiates its own PRNGKey using a magic number (54). With this pull request, a random integer between 0 and 100000 is generated and this integer is then used for the PRNGKey when, as a function argument, the PRNGKey is None.