Boolean DeviceArray to DLPack #4719

Closed
yuanqing-wang opened this issue Oct 27, 2020 · 4 comments
Labels
question Questions for the JAX team

Comments

@yuanqing-wang

Hi, it seems that a Boolean DeviceArray cannot be converted to DLPack?

>>> import jax
>>> from jax import numpy as jnp
>>> x = jnp.array([True, False])
>>> from jax import dlpack
>>> dlpack.to_dlpack(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/wangy1/Documents/GitHub/jax/jax/_src/dlpack.py", line 42, in to_dlpack
    buf, take_ownership=take_ownership)
RuntimeError: Unimplemented: XLA type PRED has no DLPack equivalent

@hawkinsp
Collaborator

That's because DLPack has no well-defined way to represent "bool"s:

https://github.com/dmlc/dlpack/blob/1b794e7088b754f4b0398d211452de0ab28312b5/include/dlpack/dlpack.h#L77

I could choose an encoding arbitrarily (e.g., as int8s), but that may or may not match any other system you want to share data with, and you could also achieve that by first converting to int8 and exporting that array instead.

I'm happy to add a DLPack bool export if we can agree on the form it should take...
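A minimal sketch of the workaround mentioned above (convert to int8 first, then export that array). The helper name `bool_to_int8` is hypothetical, and NumPy's `np.from_dlpack` is only an assumed stand-in for whatever library consumes the data; it is not part of the original report.

```python
import numpy as np
import jax.numpy as jnp


def bool_to_int8(mask):
    """Hypothetical helper: cast a boolean JAX array to int8 before export."""
    if mask.dtype != jnp.bool_:
        raise TypeError(f"expected a bool array, got {mask.dtype}")
    return mask.astype(jnp.int8)  # True -> 1, False -> 0


x = jnp.array([True, False])
x_int8 = bool_to_int8(x)

# NumPy stands in for the consuming library (an assumption). It reads the
# array through the DLPack protocol, so only int8 data crosses the boundary.
y = np.from_dlpack(x_int8)
print(y.dtype, y.tolist())
```

Because the cast is explicit, the consumer sees a dtype that DLPack defines, and both sides know which encoding was chosen.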

@hawkinsp hawkinsp added the question Questions for the JAX team label Oct 27, 2020
@yuanqing-wang
Author

Looks like PyTorch does convert Booleans to uint8s.

>>> import torch
>>> from torch.utils import dlpack
>>> dlpack.from_dlpack(dlpack.to_dlpack(torch.tensor([True, False])))
tensor([1, 0], dtype=torch.uint8)

@hawkinsp
Collaborator

hawkinsp commented Jan 6, 2021

My temptation would be not to do this cast automatically: in some sense, it seems better to me to be explicit and have the user cast to uint8 if that's what they want. Would it help if we improved the error message a bit to suggest this explicitly?

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jan 13, 2021
There's no corresponding import functionality, because DLPack doesn't have a representation for booleans.

Fixes jax-ml/jax#4719

PiperOrigin-RevId: 351617946
Change-Id: Ib6244be6f72c272a02d44e2e30f44d76e16bd7a7
@hawkinsp
Collaborator

This is now fixed at head. We allow the export of bool arrays as uint8, same as PyTorch. You can't import bool arrays (since there is no such thing as a DLPack bool array at the moment).
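Since bool data crosses the DLPack boundary as uint8, the receiving side has to cast back if it wants booleans again. A rough sketch of that round trip follows. The cast to uint8 is written out explicitly so the example does not depend on which JAX version is installed, and NumPy's `np.from_dlpack` is an assumed stand-in for the receiving library.

```python
import numpy as np
import jax.numpy as jnp

mask = jnp.array([True, False, True])

# Explicit cast that mirrors the uint8 encoding described above.
encoded = mask.astype(jnp.uint8)

# Consumer side (NumPy here, as an assumption): the data arrives as uint8.
received = np.from_dlpack(encoded)

# Recover the boolean dtype on the receiving side; astype makes a copy.
restored = received.astype(np.bool_)
print(received.dtype, restored.dtype)
```

The values survive the round trip unchanged; only the dtype has to be restored by hand.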
