Understanding device placement behaviour of pickled DeviceArray #5882
Suppose we pickle a Python object that is a `DeviceArray` and unpickle it on a different machine.

If the original tensor is on GPU:0 of the original machine and the new machine also has a GPU, is it guaranteed that the loaded tensor will be on the new machine's GPU:0?

If the original tensor is on GPU:1 of the original multi-GPU machine, and the new machine has multiple GPUs, are there any guarantees on which GPU (if any) the loaded tensor will be placed on?

If the original tensor is on CPU:0 of the original machine (which may or may not also have one or more GPUs), and the new machine has one or more GPUs, is there any guarantee about whether the loaded tensor will be on a CPU or GPU device?

The answer also matters for how JAX behaves when using MPI4PY to communicate tensors over the network, as MPI4PY uses the …

Thank you for any help. I can experiment empirically, but not for every device and hardware configuration, so understanding what behaviour (if any) is guaranteed would be really helpful.
Pickle is currently only partially supported in JAX. At the moment, JAX pickles arrays by forwarding the value to NumPy's pickling mechanism: https://github.com/google/jax/blob/037cc2b8213dac63ed27c91c30139ad0caf0580a/jax/interpreters/xla.py#L1264-L1265

This means that unpickled objects will be NumPy arrays rather than JAX arrays:

```python
import pickle

import jax.numpy as jnp

serialized = pickle.dumps(jnp.arange(4))
arr = pickle.loads(serialized)
print(type(arr))
# <class 'numpy.ndarray'>
```

At this point, if used in a JAX function, these arrays will be pushed to the default device, or you can manually push them to a device of your choosing using `jax.device_put`.
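To see why no device information survives the round trip, here is a toy model of the forwarding approach described above. `ToyDeviceArray` and its `device_label` attribute are hypothetical and are not JAX's actual class; the sketch only assumes NumPy and mirrors the idea of delegating `__reduce__` to the wrapped NumPy value.

```python
import pickle

import numpy as np


class ToyDeviceArray:
    """Hypothetical stand-in for a device-resident array (not JAX's class).

    It holds a host copy of the data plus a label naming the device the
    array notionally lives on.
    """

    def __init__(self, value, device_label):
        self._value = np.asarray(value)
        self.device_label = device_label

    def __reduce__(self):
        # Forward pickling to the wrapped NumPy array: only the NumPy
        # reconstruction recipe is written, so `device_label` is dropped.
        return self._value.__reduce__()


original = ToyDeviceArray(np.arange(4), device_label="gpu:1")
restored = pickle.loads(pickle.dumps(original))
print(type(restored))  # <class 'numpy.ndarray'>
```

Because the pickle stream under this scheme contains only the NumPy data, the receiving process has no recorded device to map onto its own hardware; placement is decided after loading.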
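A minimal sketch of the explicit route, assuming JAX is installed and recent enough that arrays expose a `.devices()` method: the receiving side unpickles the payload and then chooses the target device itself with `jax.device_put`. Picking `jax.local_devices()[-1]` is purely illustrative (on a single-device machine it is just the default device). The sketch does not depend on the exact type `pickle.loads` returns, since `jax.device_put` accepts both NumPy and JAX arrays.

```python
import pickle

import jax
import jax.numpy as jnp
import numpy as np

# Sending side: serialize an array.
payload = pickle.dumps(jnp.arange(4))

# Receiving side: deserialize the payload.
loaded = pickle.loads(payload)

# Choose the placement explicitly rather than relying on defaults.
# Using the last local device is an arbitrary, illustrative choice.
target = jax.local_devices()[-1]
placed = jax.device_put(loaded, target)

print(placed.devices())    # a set containing only `target`
print(np.asarray(placed))  # [0 1 2 3]
```

Making the device choice on the receiving side keeps the behaviour independent of how the sending machine's devices were numbered.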