From be6a0e7d36c5a7dbaa18a5a42e7b9d09cff879a2 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 3 Feb 2022 20:03:37 -0500 Subject: [PATCH] Fix provenance logic for jax 0.2.28 --- numpyro/ops/provenance.py | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 2b6f9fbbf..cc2812c48 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -26,6 +26,9 @@ def process_primitive(self, primitive, tracers, params): out_tracers = out_tracers if primitive.multiple_results else [out_tracers] for t in out_tracers: t.aval.named_shape["_provenance"] = out_provenance + # Also update provenance of the cached tracer -> aval dict. + aval_cache = self.frame.tracer_to_var[id(t)].aval + aval_cache.named_shape["_provenance"] = out_provenance out_tracers = out_tracers if primitive.multiple_results else out_tracers[0] return out_tracers diff --git a/setup.py b/setup.py index 9babf2d3c..03f32abe6 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ from setuptools import find_packages, setup PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.2.13,<0.2.28" +_jax_version_constraints = ">=0.2.13" _jaxlib_version_constraints = ">=0.1.65" # Find version