Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pseudo-Random Number Generation For None Input For JAX PRNGKey In Metaclass #192

Merged
merged 11 commits into from
Jun 17, 2024

Conversation

AFg6K7h4fhy2
Copy link
Collaborator

In metaclass.py, the MCMC run require a pseudo-random key (jax.random.PRNGKey()). Without this PR, if the user does not specify their own PRNGKey, MSR instantiates its own PRNGKey using a magic number (54). With this pull request, a random integer between 0 and 100000 is generated and this integer is then used for the PRNGKey when, as a function argument, the PRNGKey is None.

@AFg6K7h4fhy2 AFg6K7h4fhy2 added the request New feature or request label Jun 14, 2024
@AFg6K7h4fhy2 AFg6K7h4fhy2 added this to the 🐺 Lycorhinus milestone Jun 14, 2024
@AFg6K7h4fhy2 AFg6K7h4fhy2 requested a review from damonbayer June 14, 2024 14:29
@AFg6K7h4fhy2 AFg6K7h4fhy2 self-assigned this Jun 14, 2024
@AFg6K7h4fhy2 AFg6K7h4fhy2 linked an issue Jun 14, 2024 that may be closed by this pull request
Copy link

codecov bot commented Jun 14, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.76%. Comparing base (657a206) to head (dd5de7e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #192      +/-   ##
==========================================
+ Coverage   94.73%   94.76%   +0.02%     
==========================================
  Files          40       40              
  Lines         893      898       +5     
==========================================
+ Hits          846      851       +5     
  Misses         47       47              
Flag Coverage Δ
unittests 94.76% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@AFg6K7h4fhy2
Copy link
Collaborator Author

In test_hospitalizations.py (and other test files) would people like me to also change the following behavior (example) or it's fine?

model1.run(
        num_warmup=500,
        num_samples=500,
        rng_key=jax.random.PRNGKey(272),
        observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions,
    )

@AFg6K7h4fhy2
Copy link
Collaborator Author

I grandfathered in the type jax.random.PRNGKey in

def run(
        self,
        num_warmup,
        num_samples,
        rng_key: jax.random.PRNGKey | None = None,
        nuts_args: dict = None,
        mcmc_args: dict = None,
        **kwargs,
    ) -> None:

However, the tests fail here, because of jax.random.PRNGKey. Inspection suggests that ArrayLike from from jax.typing import ArrayLike is the correct type to be using:

>>> import jax.random as jr
>>> rand_int_key = jr.PRNGKey(0)
>>> rand_int = jr.randint(
...                 rand_int_key, shape=(), minval=0, maxval=100000
...             )
>>> rng_key = jr.PRNGKey(rand_int)
>>> type(rng_key)
<class 'jaxlib.xla_extension.ArrayImpl'>

@AFg6K7h4fhy2 AFg6K7h4fhy2 marked this pull request as ready for review June 14, 2024 17:00
@damonbayer
Copy link
Collaborator

Might be a good time to switch to the new style of random keys https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#jep-9263

@AFg6K7h4fhy2 AFg6K7h4fhy2 marked this pull request as draft June 14, 2024 17:19
@AFg6K7h4fhy2
Copy link
Collaborator Author

Some links:

For the following, ArrayLike still works, but (see last link) KeyArray might be implemented as a type soon.

>>> import jax.random as jr
>>> import numpy as np
>>> rand_int = np.random.randint(0, 100000)
>>> rng_key = jr.key(rand_int)
>>> type(rng_key)
<class 'jax._src.prng.PRNGKeyArray'>

@AFg6K7h4fhy2 AFg6K7h4fhy2 marked this pull request as ready for review June 14, 2024 20:06
@AFg6K7h4fhy2
Copy link
Collaborator Author

NOTE: At some point, having rng_key: ArrayLike | None = None might be problematic in the sense that, while jax.random.key() and jax.random.PRNGKey() both are ArrayLike (for the time being, a KeyArray type in jax.typing might be made soon), the ArrayLike has to be from jax.random.key() or jax.random.PRNGKey() not from a typical JAX array, e.g. jax.numpy.array([2,3,1]).

@damonbayer
Copy link
Collaborator

Can we add a test that checks that two models with no rng key do not produce the same results?

Clever and nice.

Co-authored-by: Damon Bayer <[email protected]>
Copy link
Collaborator

@damonbayer damonbayer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid so much repeated code, I suggest writing some helper functions.

One to create the model, so you can have

model_1 = create_model()
model_2 = create_model()

and one to sample from the model so you can do

sample_model(model_1, rng_key = jr.key(54))
sample_model(model_2, rng_key = jr.key(54))

@damonbayer damonbayer merged commit 407bc4e into main Jun 17, 2024
8 checks passed
@damonbayer damonbayer deleted the 146-UPX3-remove-default-seed-value-from-run branch June 17, 2024 21:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
request New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove default seed value from run
2 participants