New TypeError: <class 'function'> is not a valid JAX type exception in JAX 0.2.0 #4416

Closed
rlouf opened this issue Sep 29, 2020 · 4 comments
Labels
question Questions for the JAX team


@rlouf

rlouf commented Sep 29, 2020

I am working on a probabilistic programming library that makes heavy use of JAX. Since the release of the 0.2.0 version, JAX returns the following exception when executing the sampling code:

    @jax.jit
    def update_chains(rng_key, parameters, chain_state):
>       kernel = self.kernel_factory(*parameters)
E       jax.traceback_util.FilteredStackTrace: TypeError: <class 'function'> is not a valid JAX type
E       
E       The stack trace above excludes JAX-internal frames.
E       The following is the original exception that occurred, unmodified.
E       
E       --------------------

mcx/sampling.py:207: FilteredStackTrace

The above exception was the direct cause of the following exception:

mcx/sampling.py:245: in run
    keys, self.parameters, state
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:1220: in batched_fun
    axis_name=axis_name)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:36: in batch
    return batched_fun.call_wrapped(*in_vals)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/linear_util.py:151: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:215: in f_jitted
    donated_invars=donated_invars)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
    return call_bind(self, fun, *args, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
    return trace.process_call(self, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:171: in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
    return call_bind(self, fun, *args, **params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
    return trace.process_call(self, fun, tracers, params)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:940: in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:1005: in trace_to_subjaxpr_dynamic
    out_tracers = map(trace.full_raise, ans)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/util.py:35: in safe_map
    return list(map(f, *args))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:358: in full_raise
    return self.pure(val)
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:897: in new_const
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:821: in get_aval
    return concrete_aval(x)

I downgraded to 0.1.77 and it runs fine. Is that intended? If so, is there a workaround?

Note that I have already run into this exception in the past, but never for this part of the code.

@mattjj mattjj self-assigned this Sep 29, 2020
@mattjj mattjj added the question Questions for the JAX team label Sep 29, 2020
@mattjj
Collaborator

mattjj commented Oct 3, 2020

Sorry for being slow to respond!

It is intended. Not that we intended to break your code, of course!

This error arises when the output of a jitted function is not a pytree of valid JAX types, i.e. not a pytree of arrays (note the out_tracers = map(trace.full_raise, ans) line in the traceback). In particular, it looks like a Python callable is being returned from a jitted function.

We never intended to support returning non-jaxtype values (indeed, the jit docstring says the decorated function's "arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof."), but before 0.2.0 it used to "work" accidentally.
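As a small illustration of the rule quoted from the jit docstring, the sketch below jits two toy functions (the names returns_containers and returns_closure are made up for this example). The first returns a nested dict/tuple of arrays, which is a valid pytree output; the second returns a Python closure, which JAX rejects with a TypeError. The exact wording of the error message varies between JAX versions, so the sketch only checks the exception type.

```python
import jax
import jax.numpy as jnp


@jax.jit
def returns_containers(x):
    # Nested standard containers of arrays form a pytree: a valid jit output.
    return {"sum": jnp.sum(x), "pair": (x * 2.0, x + 1.0)}


@jax.jit
def returns_closure(x):
    # A Python function is not an array, so it is not a valid jit output.
    def scale(y):
        return x * y
    return scale


out = returns_containers(jnp.arange(3.0))
print(out["sum"])  # 3.0

try:
    returns_closure(jnp.arange(3.0))
    closure_error = None
except TypeError as err:  # "... is not a valid JAX type" (wording varies)
    closure_error = err
print(type(closure_error).__name__)
```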

I say "work" in quotes because:

  1. any returned non-jaxtype values had to be constants, i.e. could not depend on the values of the jitted function arguments, and
  2. if those returned non-jaxtype values contained any JAX tracers, e.g. by being functions with tracers in their closure or being other kinds of non-pytree containers, those leaked tracers would cause mysterious and opaque errors downstream.

In 0.2.0 we stopped supporting that never-intended and opaque-bug-prone behavior.

However, because of reason (1) above, in any case where this used to work it shouldn't be too hard to revise the code not to return the function-valued outputs, since they were constants anyway and so didn't need to be returned from the jitted function.
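A minimal sketch of that revision, assuming a setup loosely modeled on the traceback above: kernel_factory, STEP_SIZE, update_and_return_kernel and update_chains are hypothetical names, not MCX's actual code. The first jitted function returns a kernel that does not depend on its arguments (the case that used to "work"); the revision builds that constant kernel once outside jit and returns only arrays.

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 0.1  # hypothetical constant hyperparameter


def kernel_factory(step_size):
    """Hypothetical stand-in for a factory returning a transition kernel."""
    def kernel(rng_key, state):
        noise = jax.random.normal(rng_key, state.shape)
        return state + step_size * noise
    return kernel


# Pattern rejected since 0.2.0: the kernel does not depend on the jitted
# function's arguments, yet it is returned as a second output.
@jax.jit
def update_and_return_kernel(rng_key, chain_state):
    kernel = kernel_factory(STEP_SIZE)
    return kernel(rng_key, chain_state), kernel


# Revised pattern: the kernel is a constant, so build it once outside jit
# and return only arrays from the jitted function.
kernel = kernel_factory(STEP_SIZE)


@jax.jit
def update_chains(rng_key, chain_state):
    return kernel(rng_key, chain_state)


key = jax.random.PRNGKey(0)

try:
    update_and_return_kernel(key, jnp.zeros(4))
    returned_kernel_error = None
except TypeError as err:
    returned_kernel_error = err

# Batch over three chains, similar to the vmap-of-jit call in the traceback.
keys = jax.random.split(key, 3)
states = jax.vmap(update_chains)(keys, jnp.zeros((3, 4)))
print(states.shape)  # (3, 4)
```

If the kernel really did depend on traced arguments, the same idea applies: construct and apply it inside the jitted function and return its array outputs rather than the function.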

What do you think?

@hawkinsp
Collaborator

Closing due to no response. Feel free to reopen if you want to add something!

@rlouf
Author

rlouf commented Nov 28, 2020

Thank you for your thorough explanation, @mattjj! There should be a section in the docs collecting all your answers from the issue tracker :)

The issue was a quick fix indeed and everything is working now!

@Joy-Preetha

TypeError: Argument 'cpu:0' of type <class 'jaxlib.xla_extension.CpuDevice'> is not a valid JAX type.

I encountered this error while trying to run the code below.

def match_faces3(desc):
    #print(descriptors)
    #print(database.shape)
    distances = np.empty((len(desc), len(database)))
    #print("Descriptors ", descriptors)
    f.write("Descriptors")
    f.write(str(desc))
    time1 = time.time()
    for i, descr in enumerate(desc):
        for j, identity in enumerate(database):
            dist = []
            for k, id_desc in enumerate(identity[1]):
                dist.append(cosine_dist(descr, id_desc))
            distances[i][j] = dist[jnp.argmin(jnp.asarray(dist))]
    time2 = time.time() - time1
    print("time2", time2)
    distances = distances.tolist()
    return jnp.asarray(distances)

jax.device_put(descriptors,jax.devices()[0])
matches_jit = jit(match_faces3)
%timeit matches_jit(jax.devices()[0])

I tried everything but couldn't figure out the problem. Please help @mattjj
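The error message names the value 'cpu:0' of type CpuDevice, which matches the call matches_jit(jax.devices()[0]): the jitted function receives a Device object as its argument, while jit only accepts arrays, scalars, or (nested) standard Python containers thereof. The sketch below shows the difference under several assumptions that are not in the original code: cosine_dist is a hypothetical "1 minus cosine similarity" helper, and database is assumed to be reshaped into a dense (m, k, d) array, i.e. every identity has the same number k of descriptors. It also drops the file writes, the time.time() calls and the in-place NumPy assignment, because Python side effects run only while jit traces the function and writing traced values into a NumPy array fails under jit.

```python
import jax
import jax.numpy as jnp


def cosine_dist(a, b):
    # Hypothetical helper: 1 - cosine similarity along the last axis.
    a = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return 1.0 - jnp.sum(a * b, axis=-1)


@jax.jit
def match_faces(desc, database):
    # desc: (n, d) query descriptors.
    # database: (m, k, d), i.e. m identities with k descriptors each
    # (an assumed layout; the original stores a Python list of identities).
    # Broadcasting yields all pairwise distances with shape (n, m, k).
    dists = cosine_dist(desc[:, None, None, :], database[None, :, :, :])
    # Keep the closest descriptor per (query, identity) pair: shape (n, m).
    return jnp.min(dists, axis=-1)


device = jax.devices()[0]

# Passing the Device object itself as an argument reproduces the TypeError.
try:
    match_faces(device, device)
    device_error = None
except TypeError as err:
    device_error = err

# Instead, place the arrays on the device and pass the arrays to jit.
key_q, key_db = jax.random.split(jax.random.PRNGKey(0))
desc = jax.device_put(jax.random.normal(key_q, (2, 8)), device)
database = jax.device_put(jax.random.normal(key_db, (5, 3, 8)), device)
distances = match_faces(desc, database)
print(distances.shape)  # (2, 5)
```

Note that jax.device_put returns a new array rather than moving its input in place, so its result is what should be passed to the jitted function.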
