Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the provenance logic to prepare for the removal of jax named shape #1837

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp


def eval_provenance(fn, **kwargs):
Expand Down Expand Up @@ -53,18 +52,11 @@ def eval_provenance(fn, **kwargs):
# get provenances of flatten kwargs
aval_kwargs = {}
for n, v in kwargs.items():
aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})})
aval_kwargs[n] = jax.tree.map(lambda _: aval, v)
aval_args, _ = jax.tree.flatten(((), aval_kwargs))
provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args)
aval_kwargs[n] = jax.tree.map(lambda _: frozenset({n}), v)
provenance_inputs, _ = jax.tree.flatten(((), aval_kwargs))

provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs)
out_flat = []
for v, p in zip(avals_out, provenance_outputs):
val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p})
out_flat.append(val)
out = jax.tree.unflatten(out_tree(), out_flat)
return jax.tree.map(lambda x: x.named_shape["provenance"], out)
return jax.tree.unflatten(out_tree(), provenance_outputs)


def track_deps_jaxpr(jaxpr, provenance_inputs):
Expand Down
Loading