Skip to content

Commit

Permalink
chore: disable everytime and enable x64 for power laws
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash committed Sep 11, 2024
1 parent f1da2d5 commit 42ed59d
Showing 1 changed file with 17 additions and 34 deletions.
51 changes: 17 additions & 34 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,10 +1210,9 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
jax_dist = jax_dist_cls(*params)
# Enable 64bit support for higher accuracy
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
rng_key = random.PRNGKey(0)
expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
Expand Down Expand Up @@ -1260,10 +1259,9 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_infer_shapes(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
shapes = []
for param in params:
if param is None:
Expand All @@ -1289,10 +1287,9 @@ def test_infer_shapes(jax_dist, sp_dist, params):
)
def test_has_rsample(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
masked_dist = jax_dist.mask(False)
indept_dist = jax_dist.expand_by([2]).to_event(1)
transf_dist = dist.TransformedDistribution(jax_dist, biject_to(constraints.real))
Expand Down Expand Up @@ -1347,10 +1344,9 @@ def test_sample_gradient(jax_dist, sp_dist, params):
"StudentT": ["df"],
}.get(jax_dist.__name__, [])

numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

dist_args = [
p
Expand Down Expand Up @@ -1447,10 +1443,9 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
):
pytest.xfail(reason="non-jittable params")

numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

rng_key = random.PRNGKey(0)
samples = jax_dist(*params).sample(key=rng_key, sample_shape=(2, 3))
Expand All @@ -1469,10 +1464,9 @@ def log_likelihood(*params):
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
@pytest.mark.parametrize("jit", [False, True])
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
jit_fn = _identity if not jit else jax.jit
jax_dist = jax_dist(*params)

Expand Down Expand Up @@ -1536,10 +1530,9 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy_scipy(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)

Expand All @@ -1561,10 +1554,9 @@ def test_entropy_scipy(jax_dist, sp_dist, params):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + BASE
)
def test_entropy_samples(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)

Expand Down Expand Up @@ -1610,10 +1602,9 @@ def test_mixture_log_prob():
)
@pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
def test_cdf_and_icdf(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
d = jax_dist(*params)
if d.event_dim > 0:
pytest.skip("skip testing cdf/icdf methods of multivariate distributions")
Expand Down Expand Up @@ -1666,10 +1657,9 @@ def test_gof(jax_dist, sp_dist, params):
pytest.skip("EulerMaruyama skip test when event shape is non-trivial.")
if jax_dist is dist.ZeroSumNormal:
pytest.skip("skip gof test for ZeroSumNormal")
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

num_samples = 10000
if "BetaProportion" in jax_dist.__name__:
Expand Down Expand Up @@ -1700,10 +1690,9 @@ def test_gof(jax_dist, sp_dist, params):

@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DISCRETE)
def test_independent_shape(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
d = jax_dist(*params)
batch_shape, event_shape = d.batch_shape, d.event_shape
shape = batch_shape + event_shape
Expand Down Expand Up @@ -1889,10 +1878,9 @@ def test_log_prob_gradient(jax_dist, sp_dist, params):
pytest.skip("we have separated tests for LKJCholesky distribution")
if jax_dist is _ImproperWrapper:
pytest.skip("no param for ImproperUniform to test for log_prob gradient")
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

rng_key = random.PRNGKey(0)
value = jax_dist(*params).sample(rng_key)
Expand Down Expand Up @@ -2123,10 +2111,9 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
):
pytest.skip(f"{jax_dist.__name__} is a function, not a class")

numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]]

Expand Down Expand Up @@ -2806,10 +2793,9 @@ def test_generated_sample_distribution(
"{} sampling method taken from upstream, no need to"
"test generated samples.".format(jax_dist.__name__)
)
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

jax_dist = jax_dist(*params)
if sp_dist and not jax_dist.event_shape and not jax_dist.batch_shape:
Expand Down Expand Up @@ -2852,10 +2838,9 @@ def test_zero_inflated_enumerate_support():
@pytest.mark.parametrize("prepend_shape", [(), (2, 3)])
@pytest.mark.parametrize("sample_shape", [(), (4,)])
def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
jax_dist = jax_dist(*params)
new_batch_shape = prepend_shape + jax_dist.batch_shape
expanded_dist = jax_dist.expand(new_batch_shape)
Expand Down Expand Up @@ -2998,10 +2983,9 @@ def f(x, data):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_dist_pytree(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)

def f(x):
return jax_dist(*params)
Expand Down Expand Up @@ -3287,10 +3271,9 @@ def _tree_equal(t1, t2):
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_vmap_dist(jax_dist, sp_dist, params):
numpyro.enable_x64(False)
if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
numpyro.enable_x64()
else:
numpyro.enable_x64(False)
param_names = list(inspect.signature(jax_dist).parameters.keys())
vmappable_param_idxs = _get_vmappable_dist_init_params(jax_dist)
vmappable_param_idxs = vmappable_param_idxs[: len(params)]
Expand Down

0 comments on commit 42ed59d

Please sign in to comment.