DLPack no longer works on Boolean tensors after 1.10+ #67081
Comments
So I think DLPack doesn't support boolean tensors, and the workaround for this is to first convert the tensor to uint8 before exporting it. @BarclayII would that work for you? If I'm correct, the alternative would be for PyTorch to do this bool -> uint8 conversion for the caller, but that may hide the unexpected behavior and lead to confusion when users reimport the same tensor and find it is now uint8.
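A minimal sketch of that workaround (the explicit cast back to bool on reimport is an illustration, not something PyTorch does automatically):

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

mask = torch.tensor([True, False, True])

# Export as uint8, since DLPack (at the time of this discussion) has no bool dtype.
capsule = to_dlpack(mask.to(torch.uint8))

# The consumer sees a uint8 tensor; cast back to bool explicitly if the
# original semantics are needed.
restored = from_dlpack(capsule).to(torch.bool)
assert torch.equal(mask, restored)
```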
Yes indeed, see dmlc/dlpack#75 for discussion.
This implies it did work before - what was it doing there exactly? Did it round-trip correctly? What happened when you'd read in such an exported bool tensor with CuPy or JAX?
Historically (as of PyTorch 1.9), PyTorch would export the data as a uint8 tensor.
So no, it wouldn't round-trip. It was equivalent to the proposed workaround (first cast the tensor to uint8).
Then I would say that raising an exception as in 1.10 is the desired behavior.
This would also become an issue if DLPack resolves the issue I linked above and implements a bool dtype. That would likely force a backwards-incompatible change in the PyTorch implementation.
I agree with your thinking (as usual), @rgommers. Unfortunately we did make an unexpected BC-breaking change by clarifying this DLPack behavior, and we're sorry that's so disruptive, @BarclayII. It is probably the "right" change from a UX perspective, however. What are your thoughts, @BarclayII?
I understand. As for DGL, we have temporarily worked around it, so this is no longer a major blocker. One further question though: since PyTorch is deprecating the use of ByteTensors to index into an array, I'm wondering whether PyTorch has any plan to remove this support in the next release. If so, then I think it's better to enable boolean type support in DLPack.
Glad to hear it!
Removing support for what, exactly? DLPack has an issue for this already that you may want to comment on, too: dmlc/dlpack#75.
I meant removing the support for using ByteTensors for masked indexing.
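For context, the deprecation being referred to (uint8/Byte masks versus bool masks for indexing) looks roughly like this on recent PyTorch releases:

```python
import torch

x = torch.arange(4)
byte_mask = torch.tensor([1, 0, 1, 0], dtype=torch.uint8)
bool_mask = byte_mask.to(torch.bool)

# Indexing with a uint8 mask still works but emits a deprecation warning;
# bool masks are the supported path.
print(x[byte_mask])  # tensor([0, 2]), with a UserWarning
print(x[bool_mask])  # tensor([0, 2]), no warning
```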
Aha, I don't think anyone is actively working on removing that support. That's a good connection to make, however.
🐛 Bug

torch.utils.dlpack.to_dlpack no longer works for Boolean tensors.

To Reproduce
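A minimal reproduction on PyTorch 1.10 (reconstructed; the original snippet from the report is not shown here):

```python
import torch
from torch.utils.dlpack import to_dlpack

t = torch.tensor([True, False, True])
to_dlpack(t)  # on 1.10 this raises a RuntimeError rather than returning a capsule
```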
Expected behavior
Should work.
Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

- How you installed PyTorch (conda, pip, source): pip

Additional context
According to dmlc/dgl#3406, the error is raised after #57110. What was the reason behind removing boolean support?
This is blocking DGL's patch release with PyTorch 1.10 since we rely on DLPack to interact with PyTorch tensors in multiple places.
+@jermainewang @VoVAllen
cc @brianjo @mruberry @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser