Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

filter out tests waiting for next tfp release #1817

Merged
merged 7 commits into from
Jun 22, 2024

Conversation

juanitorduz
Copy link
Contributor

Partially addresses #1814 . We must keep in mind removing these skip statements once we see a new release.

@juanitorduz
Copy link
Contributor Author

juanitorduz commented Jun 19, 2024

There is a new test failing probably because the new jax and/or numpy releases

__________________ test_discrete_site_without_infer_enumerate __________________

    def test_discrete_site_without_infer_enumerate():
        def model():
            numpyro.sample("x", dist.Bernoulli(0.5))
    
        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
        with pytest.warns(FutureWarning, match="enumerated sites"):
>           mcmc.run(random.PRNGKey(0))
E           FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

test/infer/test_mcmc.py:1104: FutureWarning

I added a different match group in aa30c69 but I think it is essential to address these warnings. Especially because we are also getting

DeprecationWarning: numpy.core.numeric is deprecated and has been renamed to numpy._core.numeric. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.numeric.normalize_axis_tuple.
    from numpy.core.numeric import normalize_axis_tuple

test/infer/test_mcmc.py::test_discrete_site_without_infer_enumerate
  /Users/juanitorduz/Documents/envs/numpyro-env/lib/python3.12/site-packages/jax/_src/linear_util.py:192: DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is deprecated. Please use 'x', 'min', and 'max' respectively instead.

@juanitorduz
Copy link
Contributor Author

It seems

FutureWarning: unhashable type: <class 'jax._src.interpreters.batching.BatchTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

is all over the place now 😓

@fehiepsi
Copy link
Member

Thanks @juanitorduz. I'm looking at them.

@juanitorduz
Copy link
Contributor Author

Thanks @juanitorduz. I'm looking at them.

ok! You can either push to this branch or create a new one if needed

@fehiepsi
Copy link
Member

It turns out that in funsor, we have some checks for tracers to be Hashable. I don't think that the new behavior will cause issues: it is fine to let arrays to be either hashable or unhashable. So I think we can simply filter out these warnings:

  • in __init__, at the top of the file:
import warnings

warnings.filterwarnings("ignore", message=".*Attempting to hash a tracer.*", category=FutureWarning)
  • in pyproject.toml: "ignore:.*Attempting to hash a tracer:FutureWarning", for pytest

@juanitorduz
Copy link
Contributor Author

ok! We are making progress 😅 ! Now we have

FAILED test/contrib/test_control_flow.py::test_scan - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[10])', 'ShapedArray(float32[10])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])'], []).
FAILED test/contrib/test_control_flow.py::test_scan_svi - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[3,5])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[5])'], []).

and

TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[10])'], []).
test/test_examples.py::test_cpu[holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2] Running:
python examples/holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2

I am so seeing

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

🤔

@juanitorduz
Copy link
Contributor Author

It seems like a type problem ... could this be again a numpy or jax recent change?

numpyro/infer/__init__.py Outdated Show resolved Hide resolved
@fehiepsi
Copy link
Member

Hi @juanitorduz, I raised the upstream error in jax-ml/jax#22045. For a fix, could you help me change every device_put(foo) in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/control_flow/scan.py to

tree_map(device_put, foo)

@juanitorduz
Copy link
Contributor Author

In 19f8232 I still saw other test failing:

FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long

Hence: 038dd84

@juanitorduz
Copy link
Contributor Author

Ok! Finally is 🟢!

Shall we revert these changes once the JAX issue is fixed and released? Similarly for tfp next release?

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the issues, @juanitorduz!!

@fehiepsi fehiepsi merged commit 9785376 into pyro-ppl:master Jun 22, 2024
4 checks passed
@juanitorduz juanitorduz deleted the tfp_tests_skip branch June 30, 2024 14:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants