Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import error under JAX 0.4.1 #25

Closed
junpenglao opened this issue Dec 17, 2022 · 2 comments
Closed

Import error under JAX 0.4.1 #25

junpenglao opened this issue Dec 17, 2022 · 2 comments

Comments

@junpenglao
Copy link

Version: jax-0.4.1 jaxlib-0.4.1

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from oryx.core.ppl import random_variable

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Oryx is a neural network mini-library built on top of Jax."""
---> 16 from oryx import bijectors
     17 from oryx import core
     18 from oryx import distributions

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/__init__.py:19
     16 import inspect
     18 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 19 from oryx.bijectors import bijector_extensions
     21 tfb = tfp.bijectors
     23 _bijectors = {}

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/bijector_extensions.py:28
     26 from six.moves import zip
     27 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 28 from oryx import core
     29 from oryx.core.interpreters import inverse
     31 safe_map = jax_util.safe_map

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Contains core logic for Oryx classes."""
---> 16 from oryx.core import ppl
     17 from oryx.core.interpreters.inverse import ildj
     18 from oryx.core.interpreters.inverse import ildj_registry

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/__init__.py:17
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 # ============================================================================
     15 # Lint as: python3
     16 """Module for probabilistic programming features."""
---> 17 from oryx.core.ppl.transformations import conditional
     18 from oryx.core.ppl.transformations import graph_replace
     19 from oryx.core.ppl.transformations import intervene

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/transformations.py:204
    201 from jax import util as jax_util
    203 from oryx.core import primitive
--> 204 from oryx.core.interpreters import harvest
    205 from oryx.core.interpreters import log_prob as lp
    207 Program = Callable[..., Any]

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/interpreters/harvest.py:136
    134 from jax.interpreters import ad
    135 from jax.interpreters import batching
--> 136 from jax.interpreters import masking
    137 from jax.interpreters import xla
    138 from jax.lib.xla_bridge import xla_client as xc

ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/jax/interpreters/__init__.py)

In [2]: import oryx
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 import oryx

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Oryx is a neural network mini-library built on top of Jax."""
---> 16 from oryx import bijectors
     17 from oryx import core
     18 from oryx import distributions

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/__init__.py:19
     16 import inspect
     18 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 19 from oryx.bijectors import bijector_extensions
     21 tfb = tfp.bijectors
     23 _bijectors = {}

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/bijector_extensions.py:28
     26 from six.moves import zip
     27 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 28 from oryx import core
     29 from oryx.core.interpreters import inverse
     31 safe_map = jax_util.safe_map

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Contains core logic for Oryx classes."""
---> 16 from oryx.core import ppl
     17 from oryx.core.interpreters.inverse import ildj
     18 from oryx.core.interpreters.inverse import ildj_registry

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/__init__.py:17
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 # ============================================================================
     15 # Lint as: python3
     16 """Module for probabilistic programming features."""
---> 17 from oryx.core.ppl.transformations import conditional
     18 from oryx.core.ppl.transformations import graph_replace
     19 from oryx.core.ppl.transformations import intervene

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/transformations.py:204
    201 from jax import util as jax_util
    203 from oryx.core import primitive
--> 204 from oryx.core.interpreters import harvest
    205 from oryx.core.interpreters import log_prob as lp
    207 Program = Callable[..., Any]

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/interpreters/harvest.py:136
    134 from jax.interpreters import ad
    135 from jax.interpreters import batching
--> 136 from jax.interpreters import masking
    137 from jax.interpreters import xla
    138 from jax.lib.xla_bridge import xla_client as xc

ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/jax/interpreters/__init__.py)
junpenglao added a commit to junpenglao/blackjax that referenced this issue Dec 17, 2022
example using oryx is currently broken (jax-ml/oryx#25)
@sharadmv
Copy link
Collaborator

I bumped the Oryx version; PTAL and see if it works!

@junpenglao
Copy link
Author

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants