Sampling breaks on JAX 0.2.0 #36

Closed
rlouf opened this issue Sep 29, 2020 · 4 comments
Labels
bug Something isn't working

Comments

@rlouf
Owner

rlouf commented Sep 29, 2020

While everything runs fine on v0.1.77, running sampling with JAX 0.2.0 returns the following error:

    @jax.jit
    def update_chains(rng_key, parameters, chain_state):
>       kernel = self.kernel_factory(*parameters)
E       jax.traceback_util.FilteredStackTrace: TypeError: <class 'function'> is not a valid JAX type
E       
E       The stack trace above excludes JAX-internal frames.
E       The following is the original exception that occurred, unmodified.
E       
E       --------------------

mcx/sampling.py:207: FilteredStackTrace

The above exception was the direct cause of the following exception:

mcx/sampling.py:245: in run
    keys, self.parameters, state
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:1220: in batched_fun
    axis_name=axis_name)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:36: in batch
    return batched_fun.call_wrapped(*in_vals)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/linear_util.py:151: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:215: in f_jitted
    donated_invars=donated_invars)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
    return call_bind(self, fun, *args, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
    return trace.process_call(self, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:171: in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
    return call_bind(self, fun, *args, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
    return trace.process_call(self, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:940: in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:1005: in trace_to_subjaxpr_dynamic
    out_tracers = map(trace.full_raise, ans)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/util.py:35: in safe_map
    return list(map(f, *args))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:358: in full_raise
    return self.pure(val)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:897: in new_const
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:821: in get_aval
    return concrete_aval(x)
@rlouf rlouf added the bug Something isn't working label Sep 29, 2020
@rlouf
Owner Author

rlouf commented Sep 29, 2020

Raised an issue on the JAX repo: jax-ml/jax#4416 (comment)

@lmmx
Contributor

lmmx commented Oct 19, 2020

Just to follow up on this issue: the JAX team responded on the thread, in short saying it’s a deliberate change:

This error arises when the output of a jitted function is not a pytree of valid jax types, i.e. not a pytree of arrays (notice the out_tracers = map(trace.full_raise, ans) line in the traceback). In particular, it looks like a Python callable is being returned from a jitted function.

We never intended to support returning non-jaxtype values (indeed the jit docstring says the decorated function's "arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof."). But before 0.2.0 it used to "work" accidentally.

I say "work" in quotes because:

  1. any returned non-jaxtype values had to be constants, i.e. could not depend on the values of the jitted function arguments, and
  2. if those returned non-jaxtype values contained any JAX tracers, e.g. by being functions with tracers in their closure or being other kinds of non-pytree containers, those leaked tracers would cause mysterious and opaque errors downstream.

In 0.2.0 we stopped supporting that never-intended and opaque-bug-prone behavior.

and then as for what to do (emphasis added):

However, due to reason (1) above, in any case where this used to work, it shouldn't be too hard to revise the code not to return the function-valued arguments, since they were constants anyway and so didn't need to be returned from the jitted function.

and they’ve closed the issue (awaiting any further questions)
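
To make the quoted explanation concrete, here is a minimal sketch of the behaviour, assuming a recent JAX install. The names `make_scaler`, `returns_function` and `returns_array` are made up for illustration and are not part of MCX or JAX. The first jitted function returns a Python callable and is rejected; the second uses the same constant callable internally and returns only an array.

```python
import jax
import jax.numpy as jnp


def make_scaler(factor):
    """Hypothetical factory: returns a plain Python callable."""
    return lambda x: factor * x


@jax.jit
def returns_function(x):
    # The return value is a callable, not a pytree of arrays,
    # so tracing this function fails with a TypeError.
    return make_scaler(2.0)


@jax.jit
def returns_array(x):
    # The callable is a constant: build it while tracing, apply it,
    # and return only the resulting array.
    scale = make_scaler(2.0)
    return scale(x)


try:
    returns_function(jnp.ones(3))
    raised = False
except TypeError:
    raised = True

result = returns_array(jnp.ones(3))
```

Since the callable never depends on the traced arguments, nothing is lost by keeping it out of the return value.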

@rlouf
Owner Author

rlouf commented Oct 19, 2020

Thanks @lmmx! I did not respond yet as I did not have time to look into it and did not want to pollute their issue tracker. Will reopen if needed when I get to that. In the meantime I pinned the version to the previous version so I can keep testing stuff.

@rlouf
Owner Author

rlouf commented Nov 20, 2020

I think that the problem here stems from the fact that kernel_factory is captured by update_chain from the enclosing scope, and JIT-compiled when update_chain is JIT-compiled. I can't see any other explanation as kernel_factory is not decorated by @jax.jit.

If that is the issue, the solution is simple: we have to pass kernel_factory as an argument to update_chain, as follows (note that the parameters are constant here as well, so we can tell the JIT compiler to treat them as static):

@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    # kernel_factory is now an explicit static argument, so call it directly
    # instead of going through self.kernel_factory.
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info

And later in the update_loop function which advances all the chains:

@functools.partial(jax.jit, static_argnums=(2, 3))
def update_loop(state, key, kernel_factory, parameters):
    keys = jax.random.split(key, num_chains)
    # Map update_chain (defined above) over the chains; kernel_factory is shared, not mapped.
    state, info = jax.vmap(update_chain, in_axes=(0, None, 0, 0))(keys, kernel_factory, parameters, state)
    return state, info, mcx_ravel_pytree((state, info))[0]

This is slightly more verbose; if it becomes unwieldy, kernel_factory and parameters could be gathered in a NamedTuple.
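
A rough, self-contained sketch of the NamedTuple idea, under several assumptions: `KernelSpec` and `random_walk_factory` are hypothetical names, the toy random-walk kernel stands in for a real MCX kernel, and the per-chain call is wrapped in a closure rather than passing the factory through `in_axes`. Static arguments must be hashable, which is why the parameters are stored as a tuple.

```python
import functools
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class KernelSpec(NamedTuple):
    """Hypothetical bundle of the two compile-time constants."""
    factory: Callable
    parameters: Tuple[float, ...]


def random_walk_factory(step_size):
    """Toy stand-in for a kernel factory; not MCX's actual kernel."""
    def kernel(rng_key, position):
        new_position = position + step_size * jax.random.normal(rng_key)
        info = jnp.abs(new_position - position)
        return new_position, info
    return kernel


@functools.partial(jax.jit, static_argnums=(1,))
def update_chain(rng_key, spec, chain_state):
    # spec is static: the kernel is built at trace time and only
    # arrays (new state, info) are returned.
    kernel = spec.factory(*spec.parameters)
    return kernel(rng_key, chain_state)


@functools.partial(jax.jit, static_argnums=(2,))
def update_loop(states, key, spec):
    # One key per chain; the chain count is read from the static shape.
    keys = jax.random.split(key, states.shape[0])
    # Close over spec so that vmap only sees array arguments.
    step = lambda rng_key, chain_state: update_chain(rng_key, spec, chain_state)
    return jax.vmap(step)(keys, states)


spec = KernelSpec(random_walk_factory, (0.1,))
states, info = update_loop(jnp.zeros(4), jax.random.PRNGKey(0), spec)
```

Passing a different `KernelSpec` (another factory or other parameter values) triggers a recompilation, which is the expected cost of marking it static.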
