
[distributed] NCCL Backend doesn't support torch.bool data type #24137

Closed
apsdehal opened this issue Aug 10, 2019 · 9 comments
Assignees
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix · module: boolean tensor · oncall: distributed Add this issue/PR to distributed oncall triage queue · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@apsdehal

apsdehal commented Aug 10, 2019

🐛 Bug

In version 1.2.0, the NCCL backend doesn't support the torch.bool data type. Broadcasting a tensor of this type throws the error "RuntimeError: Unsupported data type for NCCL process group".

To Reproduce

Steps to reproduce the behavior:

Create a file test.py with the following contents:

import torch
import argparse
from torch import distributed as dist


parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)

args = parser.parse_args()

torch.distributed.init_process_group("nccl")

local_rank = args.local_rank

device = torch.device(local_rank)

if local_rank == 0:
    element = False
else:
    element = True


def broadcast_scalar(scalar, src=0, device="cpu"):
    scalar_tensor = torch.tensor(scalar).to(device)
    with torch.no_grad():
        # dist.broadcast modifies scalar_tensor in place; it does not return a tensor.
        dist.broadcast(scalar_tensor, src)
    return scalar_tensor.item()


broadcast_scalar(element, src=0, device=device)

Run it with the following command:
python -u -m torch.distributed.launch --nproc_per_node 2 test.py

This has been tested on 2 GPUs.

Expected behavior

The NCCL backend should support the bool data type.
Current workaround: cast the tensor to an integer type with .long() before broadcasting.
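
A minimal sketch of the workaround (the helper name broadcast_bool is illustrative, not from the issue; it assumes the process group has already been initialized as in the repro above):

import torch
from torch import distributed as dist


def broadcast_bool(flag, src=0, device="cuda"):
    # Upcast: NCCL has no native bool type, so ship the value as int64.
    tensor = torch.tensor(flag, dtype=torch.long, device=device)
    dist.broadcast(tensor, src)  # in-place broadcast from rank `src`
    # Downcast back to a Python bool on every rank.
    return bool(tensor.item())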

Environment

Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] torch==1.2.0
[pip] torchtext==0.3.1
[pip] torchvision==0.2.2
[conda] torch                     1.2.0                     <pip>
[conda] torchtext                 0.3.1                     <pip>
[conda] torchvision               0.2.2                     <pip>
@pytorchbot pytorchbot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Aug 10, 2019
@mrshenli mrshenli added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: boolean tensor triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Aug 12, 2019
@williamFalcon
Contributor

And you get a warning saying .uint8 isn't supported so we should switch to .bool... but then NCCL doesn't support it.

@soumith
Member

soumith commented Aug 28, 2019

cc: @mrshenli @pietern we should fix this with upcast + transfer + downcast

@rohan-varma rohan-varma self-assigned this Sep 19, 2019
@rohan-varma
Member

@soumith, I'm working on this bug. Could you explain what you mean by "upcast + transfer + downcast"? It appears that the error is coming from here: https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.cpp#L45-L60, and seems to happen because ncclDataType_t doesn't have a bool type

@soumith
Member

soumith commented Sep 20, 2019

@rohan-varma I mean that we should cast the buffer from bool to uint8, then all-reduce, then cast it back to bool on the other side.

@pietern
Contributor

pietern commented Sep 20, 2019

Things to keep in mind:

  • Can only use uint8_t with up to 255 processes in the process group. We rely on every process contributing an integer equal to 1 if the equivalent boolean entry is set. With 256 processes we would overflow an 8-bit unsigned integer and get the wrong result. The change should either 1) assert that the process group size is small enough, or 2) implement a separate code path that uses a 16-bit unsigned integer for larger process groups.
  • Semantics of the different reduction ops (each of which could use SUM as the underlying reduction; see the sketch after this list):
    • ReduceOp.SUM -- boolean OR, so the boolean output is output != 0
    • ReduceOp.PRODUCT -- boolean AND, so the boolean output is output == pg->size
    • ReduceOp.MIN -- boolean AND (see above)
    • ReduceOp.MAX -- boolean OR (see above)
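
A minimal sketch of these semantics (the helper name all_reduce_bool is illustrative, not an existing API; it assumes an initialized NCCL process group and CUDA input tensors):

import torch
from torch import distributed as dist


def all_reduce_bool(mask, op=dist.ReduceOp.SUM):
    # Guard against the uint8 overflow described above.
    assert dist.get_world_size() <= 255, "uint8 SUM would overflow with 256+ ranks"
    # Every rank contributes 1 wherever its boolean entry is set; the
    # underlying collective is always a SUM over uint8.
    buf = mask.to(torch.uint8)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    if op in (dist.ReduceOp.SUM, dist.ReduceOp.MAX):
        return buf != 0  # boolean OR across ranks
    if op in (dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN):
        return buf == dist.get_world_size()  # boolean AND across ranks
    raise ValueError("unsupported reduction op for bool tensors")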

@ekmb

ekmb commented Dec 3, 2019

@rohan-varma Is it fixed?

@mrcslws

mrcslws commented Jul 10, 2020

I would be really happy to see this fixed. We have lots of modules that store masks in buffers. Because of this issue, these modules are forced to use float16 or float32 mask buffers rather than bool buffers.
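
A sketch of that use case (MaskedLinear is a hypothetical module; DistributedDataParallel over NCCL broadcasts registered buffers, which is where the bool limitation bites and forces the float fallback described above):

import torch
import torch.nn as nn


class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # What we'd like: a bool mask buffer. Before the fix, DDP over NCCL
        # cannot broadcast it, so users fall back to float16/float32 buffers.
        self.register_buffer("mask", torch.ones(out_features, in_features, dtype=torch.bool))

    def forward(self, x):
        # The bool mask is promoted to float during the multiplication.
        return nn.functional.linear(x, self.linear.weight * self.mask, self.linear.bias)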

rohan-varma added a commit that referenced this issue Jul 12, 2020
Closes #24137. Since bool is not supported as a native ncclDataType_t, we add some upcasting + downcasting logic to support it.

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

rohan-varma added a commit that referenced this issue Jul 13, 2020
Closes #24137. 
This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:
1) Detect if input tensors are of bool type. If so, cast inputs & outputs to int tensors. 
2) Run the specified reduction.
3) If we had to cast, cast the outputs back to boolean tensors. If this collective does not operate in-place, then re-cast the inputs back to bool so that they are not modified as a result of the op.

The reduction logic (for example for reduce/allreduce) is as follows:
sum, max = bitwise or
product, min = bitwise and

Note that this PR doesn't add support for BAND/BOR/BXOR. That is because these reduction ops currently are not supported by NCCL backend, see #41362

Tests are added to ensure that the reductions work as expected. 
Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

rohan-varma added a commit that referenced this issue Jul 21, 2020
Closes #24137. 
This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:
1) Map `at::kBool` to `ncclUint8`
2) During reduction (allreduce, for example), if the operation is SUM, we instead override it to a MAX to avoid overflow issues. The rest of the operations work with no changes. In the boolean case, changing sum to max makes no correctness difference since they both function as a bitwise OR.

The reduction logic (for example for reduce/allreduce) is as follows:
sum, max = bitwise or
product, min = bitwise and

Note that this PR doesn't add support for BAND/BOR/BXOR. That is because these reduction ops currently are not supported by NCCL backend, see #41362

Tests are added to ensure that the reductions work as expected. 
Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

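Once this lands (targeted for PyTorch 1.7), bool collectives should work directly. A rough sketch of the expected behavior on two ranks (illustrative only, not a test from the PR; assumes one GPU per rank and a launch via torch.distributed.launch):

import torch
from torch import distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
flag = torch.tensor([rank == 0], dtype=torch.bool, device=f"cuda:{rank}")

t = flag.clone()
dist.all_reduce(t, op=dist.ReduceOp.SUM)      # acts as logical OR -> True on both ranks
u = flag.clone()
dist.all_reduce(u, op=dist.ReduceOp.PRODUCT)  # acts as logical AND -> False on both ranks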
rohan-varma added a commit that referenced this issue Jul 21, 2020
Pull Request resolved: #41318

@rohan-varma rohan-varma reopened this Jul 23, 2020
@sajadn

sajadn commented Aug 28, 2020

On PyTorch 1.5.1 I still have this problem. Is it going to be fixed?

@rohan-varma
Member

Hi @sajadn, this was landed ~1 month ago so it should be part of the next release, PT 1.7. Until then, you can try out the nightly build (see instructions at https://pytorch.org/) where this is fixed.


10 participants