[distributed] NCCL Backend doesn't support torch.bool data type #24137
Comments
And you get a warning saying `.uint8` isn't supported so we should switch to `.bool`... but then NCCL doesn't support it.
@soumith, I'm working on this bug. Could you explain what you mean by "upcast + transfer + downcast"? It appears that the error is coming from here: https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.cpp#L45-L60, and seems to happen because `at::kBool` has no entry in that `ncclDataType_t` mapping.
@rohan-varma I mean that we should cast the buffer from bool to uint8, then all-reduce, then cast it back to bool on the other side.
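Until that lands in the backend, the same idea works as a user-level workaround. A minimal sketch, assuming a NCCL process group has already been initialized and the tensor lives on the current CUDA device (the function name is illustrative):

```python
import torch
import torch.distributed as dist

def all_reduce_bool(mask: torch.Tensor) -> torch.Tensor:
    # Upcast: NCCL has no bool type, so communicate the values as uint8 (0/1).
    buf = mask.to(torch.uint8)
    # Transfer: MAX over 0/1 values behaves as a logical OR across ranks
    # (and avoids the overflow a SUM could hit with many ranks).
    dist.all_reduce(buf, op=dist.ReduceOp.MAX)
    # Downcast: back to bool on the receiving side.
    return buf.bool()
```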
Things to keep in mind:
@rohan-varma Is it fixed?
I would be really happy to see this fixed. We have lots of modules that store masks in buffers. Because of this issue, these modules are forced to use a different dtype for those buffers.
Pull Request resolved: #41318. Closes #24137. Since `bool` is not supported as a native `ncclDataType_t`, we add some upcasting + downcasting logic to support it. Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)
Closes #24137. This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:
1) Detect whether the input tensors are of bool type. If so, cast the inputs and outputs to int tensors.
2) Run the specified reduction.
3) If we had to cast, cast the outputs back to boolean tensors. If the collective does not operate in-place, also re-cast the inputs back to bool so that they are not modified as a result of the op.

The reduction logic (for reduce/allreduce, for example) is: sum and max act as a bitwise OR; product and min act as a bitwise AND. Note that this PR doesn't add support for BAND/BOR/BXOR, because those reduction ops are currently not supported by the NCCL backend, see #41362. Tests are added to ensure that the reductions work as expected. Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)
Closes #24137. This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:
1) Map `at::kBool` to `ncclUint8`.
2) During a reduction (allreduce, for example), if the operation is SUM, override it to MAX to avoid overflow issues. The rest of the operations work with no changes. In the boolean case, changing SUM to MAX makes no correctness difference, since both function as a bitwise OR.

The reduction logic (for reduce/allreduce, for example) is: sum and max act as a bitwise OR; product and min act as a bitwise AND. Note that this PR doesn't add support for BAND/BOR/BXOR, because those reduction ops are currently not supported by the NCCL backend, see #41362. Tests are added to ensure that the reductions work as expected. Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)
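For a concrete, local (non-distributed) illustration of why MAX and MIN can stand in for OR and AND on boolean values carried as uint8, consider the small check below; it is only a sketch of the semantics, not the PR's actual C++ implementation:

```python
import torch

# Bool values communicated as uint8 are only ever 0 or 1, so element-wise
# MAX behaves as logical OR and MIN (or PROD) as logical AND. This is also
# why SUM can safely be replaced by MAX without any risk of overflow.
a = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b = torch.tensor([0, 0, 1, 1], dtype=torch.uint8)

assert torch.equal(torch.max(a, b).bool(), a.bool() | b.bool())  # max == OR
assert torch.equal(torch.min(a, b).bool(), a.bool() & b.bool())  # min == AND
assert torch.equal((a * b).bool(), a.bool() & b.bool())          # prod == AND
```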
On PyTorch 1.5.1 I still have this problem. Is it going to be fixed?
Hi @sajadn, this was landed ~1 month ago so it should be part of the next release, PT 1.7. Until then, you can try out the nightly build (see instructions at https://pytorch.org/) where this is fixed.
🐛 Bug
In version 1.2.0, the NCCL backend doesn't support the `torch.bool` datatype. Broadcasting a tensor of this type throws the error "RuntimeError: Unsupported data type for NCCL process group".

To Reproduce
Steps to reproduce the behavior:
Create a file `test.py` with the following contents:
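(The original snippet is not preserved in this copy of the issue; the script below is a reconstruction consistent with the description, with an illustrative tensor shape and contents.)

```python
# test.py -- reconstructed minimal reproduction
import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # passed by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
# torch.distributed.launch also sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE,
# so the default env:// initialization works here.
dist.init_process_group(backend="nccl")

# Broadcasting a bool tensor raises:
# "RuntimeError: Unsupported data type for NCCL process group"
mask = torch.zeros(8, dtype=torch.bool, device="cuda")
dist.broadcast(mask, src=0)
print("rank", dist.get_rank(), mask)
```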
Run it with the following command:
python -u -m torch.distributed.launch --nproc_per_node 2 test.py
This has been tested on 2 GPUs.
Expected behavior
The NCCL backend should support the `torch.bool` datatype.
Current workaround: change the datatype to int by calling `.long()` before broadcasting.
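A minimal sketch of that workaround (assuming the same setup as in the reconstructed `test.py` above):

```python
import torch
import torch.distributed as dist

# Assumes the NCCL process group has already been initialized.
mask = torch.zeros(8, dtype=torch.bool, device="cuda")

buf = mask.long()           # bool -> int64, which the NCCL backend supports
dist.broadcast(buf, src=0)  # broadcast the integer stand-in
mask = buf.bool()           # cast back to bool on every rank
```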
Environment