add all_to_all phi operator (#54797)
* add all_to_all phi operator, kernel, api

* add all_to_all ut

* tinyfix
wentaoyu authored Jun 27, 2023
1 parent 7028845 commit 158b7ae
Showing 8 changed files with 306 additions and 1 deletion.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/static_ops.yaml
@@ -26,6 +26,16 @@
    func : all_reduce
    param: [x, reduce_type]

- op : all_to_all
  args : (Tensor x, int ring_id = 0)
  output : Tensor(out)
  infer_meta :
    func : AllToAllInferMeta
    param: [x]
  kernel :
    func : all_to_all
    param: [x]

- op : amax
  args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1)
  output : Tensor(out)
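The entry above only declares the static-graph op and points it at the `AllToAllInferMeta` and `all_to_all` kernel functions added in the files below. For orientation (this snippet is not part of the commit), a program can append the op exactly the way the `alltoall_new` test helper later in this diff does; a minimal sketch, assuming static mode is enabled and a process group with `ring_id = 0` has already been initialized:

```python
import paddle
from paddle import framework

paddle.enable_static()


def append_all_to_all(x, ring_id=0):
    # Mirrors the test helper added below: create the output variable and
    # append the raw `all_to_all` op declared in static_ops.yaml.
    helper = framework.LayerHelper('all_to_all', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='all_to_all',
        inputs={'x': [x]},
        outputs={'out': [out]},
        attrs={'ring_id': ring_id},
    )
    return out
```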
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -136,6 +136,13 @@ void AllReduceInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->set_dims(x.dims());
}

void AllToAllInferMeta(const MetaTensor& x, MetaTensor* out) {
  auto dim = x.dims();
  if (dim[0] < 0) dim[0] = -1;
  out->set_dtype(x.dtype());
  out->set_dims(dim);
}

void ArgMinMaxInferMeta(const MetaTensor& x,
                        const Scalar& axis,
                        bool keepdims,
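The `AllToAllInferMeta` function added above propagates the input's dtype and shape unchanged, only normalizing a dynamic (negative) leading dimension to -1. A small Python sketch of that rule, for illustration only (not Paddle code):

```python
def infer_all_to_all_shape(x_shape):
    # Mirrors AllToAllInferMeta: the output keeps the input shape, with a
    # dynamic (negative) leading dimension normalized to -1.
    out_shape = list(x_shape)
    if out_shape[0] < 0:
        out_shape[0] = -1
    return out_shape


assert infer_all_to_all_shape([-1, 10, 1000]) == [-1, 10, 1000]
assert infer_all_to_all_shape([4, 10, 1000]) == [4, 10, 1000]
```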
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -43,6 +43,8 @@ void AllGatherInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);

void AllReduceInferMeta(const MetaTensor& x, MetaTensor* out);

void AllToAllInferMeta(const MetaTensor& x, MetaTensor* out);

void ArgMinMaxInferMeta(const MetaTensor& x,
                        const Scalar& axis,
                        bool keepdims,
26 changes: 26 additions & 0 deletions paddle/phi/kernels/all_to_all_kernel.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void AllToAllKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out);

} // namespace phi
43 changes: 43 additions & 0 deletions paddle/phi/kernels/cpu/all_to_all_kernel.cc
@@ -0,0 +1,43 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/all_to_all_kernel.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void AllToAllKernel(const Context& dev_ctx UNUSED,
                    const DenseTensor& x UNUSED,
                    DenseTensor* out UNUSED) {
  PADDLE_THROW(
      errors::Unimplemented("Unimplemented cpu kernel for all_to_all."));
}

}  // namespace phi

PD_REGISTER_KERNEL(all_to_all,
                   CPU,
                   ALL_LAYOUT,
                   phi::AllToAllKernel,
                   float,
                   double,
                   int,
                   bool,
                   int8_t,
                   uint8_t,
                   int64_t,
                   phi::dtype::float16) {}
114 changes: 114 additions & 0 deletions paddle/phi/kernels/gpu/all_to_all_kernel.cu
@@ -0,0 +1,114 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/all_to_all_kernel.h"
#include "glog/logging.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/distributed/utils.h"
#endif

namespace phi {

template <typename T, typename Context>
void AllToAllKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    DenseTensor* out) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
  auto x_dims = x.dims();
  out->Resize(x_dims);
  dev_ctx.template Alloc<T>(out);

  auto comm_ctx =
      static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
  PADDLE_ENFORCE_NE(
      comm_ctx,
      nullptr,
      errors::Unavailable("NCCLCommContext is nullptr; a collective op should "
                          "have a ring_id attr."));
  gpuStream_t stream = dev_ctx.stream();
  PADDLE_ENFORCE_NOT_NULL(stream,
                          errors::NotFound("Should initialize NCCL firstly."));

  int nranks = comm_ctx->GetSize();
  int send_numel = x.numel() / nranks;
  size_t offset = 0;

  PADDLE_ENFORCE_EQ(
      x_dims[0] % nranks,
      0,
      errors::InvalidArgument(
          "The first dimension size (%d) of the input tensor must be "
          "divisible by the number of ranks (%d).",
          x_dims[0],
          nranks));

  // Send chunk i of x to rank i and receive rank i's chunk into slot i of out.
  comm_ctx->GroupStart();
  for (auto i = 0; i < nranks; ++i) {
    auto send_buf = phi::distributed::GetPartialTensor(x, offset, send_numel);
    comm_ctx->Send(send_buf, send_numel, i, stream);
    auto recv_buf =
        phi::distributed::GetPartialTensor(*out, offset, send_numel);
    comm_ctx->Recv(&recv_buf, send_numel, i, stream);
    offset += send_numel;
  }
  comm_ctx->GroupEnd();
#else
  PADDLE_THROW(
      errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
  PADDLE_THROW(
      errors::PreconditionNotMet("PaddlePaddle should compile with GPU."));
#endif
}

} // namespace phi

#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(all_to_all,
                   GPU,
                   ALL_LAYOUT,
                   phi::AllToAllKernel,
                   float,
                   double,
                   int,
                   int8_t,
                   uint8_t,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(all_to_all,
                   GPU,
                   ALL_LAYOUT,
                   phi::AllToAllKernel,
                   float,
                   double,
                   int,
                   int8_t,
                   uint8_t,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
#endif
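The Send/Recv loop above splits the input along dimension 0 into `nranks` equal chunks, sends chunk i to rank i, and receives rank i's contribution into slot i of the output. The following NumPy reference simulates that exchange across all ranks at once; it is an illustration of the semantics only, and the function name is made up:

```python
import numpy as np


def all_to_all_reference(inputs):
    """Simulate all_to_all over len(inputs) ranks.

    inputs[r] is rank r's tensor; its leading dimension must be divisible
    by the number of ranks. Output slot i of rank r receives chunk r of
    rank i, matching the Send/Recv loop in the GPU kernel above.
    """
    nranks = len(inputs)
    chunks = [np.split(x, nranks, axis=0) for x in inputs]
    return [
        np.concatenate([chunks[i][r] for i in range(nranks)], axis=0)
        for r in range(nranks)
    ]


# Two simulated ranks, each holding a (4, 2) tensor.
rank_inputs = [np.arange(8).reshape(4, 2) + 100 * r for r in range(2)]
rank_outputs = all_to_all_reference(rank_inputs)
# rank_outputs[1] holds rank 0's second chunk followed by rank 1's second chunk.
```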
90 changes: 89 additions & 1 deletion test/collective/collective_alltoall_api.py
@@ -18,11 +18,86 @@
)

import paddle
from paddle import fluid
import paddle.distributed as dist
from paddle import fluid, framework
from paddle.fluid import data_feeder

paddle.enable_static()


def alltoall_new(
    in_tensor_or_tensor_list,
    out_tensor_or_tensor_list,
    group=None,
    sync_op=True,
):
    op_type = 'all_to_all'

    ring_id = 0 if group is None else group.id
    nranks = dist.get_world_size()
    helper = framework.LayerHelper(op_type, **locals())

    in_tensor = in_tensor_or_tensor_list
    if isinstance(in_tensor_or_tensor_list, list):
        if len(in_tensor_or_tensor_list) == 0:
            raise RuntimeError("The input tensor_list should not be empty.")
        # 0-D use stack/unstack while others use concat/split
        if len(in_tensor_or_tensor_list[0].shape) == 0:
            in_tensor = paddle.stack(in_tensor_or_tensor_list, axis=0)
        else:
            in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)

    out_tensor = out_tensor_or_tensor_list
    if isinstance(out_tensor_or_tensor_list, list):
        if len(out_tensor_or_tensor_list) != 0:
            raise ValueError(
                "The 'out_tensor_list' for all_to_all must be an empty list."
            )
        out_tensor = helper.create_variable_for_type_inference(
            dtype=in_tensor.dtype
        )

    data_feeder.check_variable_and_dtype(
        in_tensor,
        'in_tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
            'uint16',
        ],
        'all_to_all',
    )
    helper.append_op(
        type=op_type,
        inputs={'x': [in_tensor]},
        outputs={'out': [out_tensor]},
        attrs={
            'ring_id': ring_id,
        },
    )
    # NOTE(liyurui): If the argument `out_tensor_or_tensor_list` is a tensor_list,
    # we need to split the result, so we should wait for the result of all_to_all
    # before splitting if the communication is not on the calc stream.
    if isinstance(out_tensor_or_tensor_list, list):
        if not sync_op:
            dist.wait(out_tensor, use_calc_stream=False)
        # 0-D use stack/unstack while others use concat/split
        if len(in_tensor_or_tensor_list[0].shape) == 0:
            out_tensor_or_tensor_list.extend(paddle.unstack(out_tensor, 0))
        else:
            out_tensor_or_tensor_list.extend(
                paddle.split(out_tensor, nranks, 0)
            )

    return None


class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0
@@ -38,6 +113,19 @@ def get_model(self, main_prog, startup_program, rank):
            paddle.distributed.alltoall(tindata, tout_data)
            return tout_data

    def get_model_new(
        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
    ):
        with fluid.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[-1, 10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            tindata = paddle.split(tindata, 2, axis=0)
            tout_data = []
            alltoall_new(tindata, tout_data)
            return tout_data


if __name__ == "__main__":
    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
15 changes: 15 additions & 0 deletions test/collective/test_collective_alltoall_api.py
@@ -28,6 +28,21 @@ def _setup_config(self):
    def test_alltoall_nccl(self):
        self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")

    def test_alltoall_nccl_with_comm_context(self):
        dtypes_to_test = [
            "float32",
        ]
        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_alltoall_api.py",
                "alltoall",
                "nccl",
                dtype=dtype,
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_alltoall_nccl_dygraph(self):
        dtypes_to_test = [
            "float16",
