Skip to content

Commit

Permalink
Add fixes for jax 0.4 (#601)
Browse files Browse the repository at this point in the history
* Add fixes for jax 0.4

* Add restriction

* Require ci to test jax with python 3.7
  • Loading branch information
fehiepsi authored Dec 17, 2022
1 parent 1b9b68d commit 0294f5e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]
python-version: [3.7,3.8]
env:
CI: 1
FUNSOR_BACKEND: jax
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ipython<=8.6.0 # restrict for https://github.com/ipython/ipython/issues/13845
makefun
multipledispatch
nbsphinx==0.8.1
Expand Down
18 changes: 14 additions & 4 deletions funsor/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import re

from jax.core import Tracer
from jax.interpreters.xla import DeviceArray
from jax.numpy import ndarray

from funsor.tensor import tensor_to_funsor
from funsor.terms import to_funsor
Expand All @@ -14,15 +16,23 @@
del _ # flake8


to_funsor.register(DeviceArray)(tensor_to_funsor)
to_funsor.register(ndarray)(tensor_to_funsor)
to_funsor.register(Tracer)(tensor_to_funsor)


@quote.register(DeviceArray)
@quote.register(ndarray)
def _quote(x, indent, out):
"""
Work around JAX's DeviceArray not supporting reproducible repr.
Work around JAX's ndarray not supporting reproducible repr.
"""
# After JAX 0.4, jnp.ones(3) is no longer a DeviceArray, but an ndarray.
# In addition, a tracer is also an ndarray - so we need to handler it
# separately here.
if isinstance(x, Tracer):
# Default implementation.
line = re.sub("\n\\s*", " ", repr(x))
out.append((indent, line))
return
if x.size >= quote.printoptions["threshold"]:
data = "..." + " x ".join(str(d) for d in x.shape) + "..."
else:
Expand Down

0 comments on commit 0294f5e

Please sign in to comment.