Understanding device placement behaviour of pickled DeviceArray #5882

Answered by jakevdp
JossWhittle asked this question in Q&A

Pickling is currently only partially supported in JAX. At the moment, JAX pickles arrays by forwarding the value to NumPy's pickling mechanism: https://github.com/google/jax/blob/037cc2b8213dac63ed27c91c30139ad0caf0580a/jax/interpreters/xla.py#L1264-L1265
This means that unpickled objects will be NumPy arrays rather than JAX arrays, and they carry no record of the device the original array lived on:

import jax.numpy as jnp
import pickle

serialized = pickle.dumps(jnp.arange(4))
arr = pickle.loads(serialized)
print(type(arr))
# <class 'numpy.ndarray'>

At this point, if you pass the unpickled array to a JAX function it will be transferred to the default device; alternatively, you can place it on a device of your choosing using jax.device_put.
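For illustration, here is a minimal sketch of both options. It assumes at least one device is visible to JAX; picking `jax.devices()[0]` is only an example choice, and the `np.asarray` call is there so the sketch behaves the same whether unpickling gives back a NumPy array (as described above) or something else in a different JAX version.

```python
import pickle

import jax
import jax.numpy as jnp
import numpy as np

# Round-trip an array through pickle and normalise the result to a host-side
# NumPy array, so the rest of the sketch does not depend on the unpickled type.
restored = np.asarray(pickle.loads(pickle.dumps(jnp.arange(4))))

# Option 1: explicit placement on a device of your choosing.
# jax.devices()[0] is an illustrative choice, not a recommendation.
target = jax.devices()[0]
on_device = jax.device_put(restored, target)

# Option 2: implicit placement. Using the array in a JAX operation
# transfers it to the default device.
implicit = jnp.add(restored, 1)
```

If placement matters after loading, for example on a multi-GPU host, the explicit `jax.device_put` route is probably the safer one, since the pickle itself does not record a device.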

Answer selected by JossWhittle