From 0a508e04d67028e8f01011366ec51b11fd85c362 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 30 Oct 2023 11:08:58 +0000 Subject: [PATCH 01/23] custom op support auto parallel --- paddle/phi/api/lib/op_meta_info.cc | 27 +- test/auto_parallel/dist_custom_relu_op.cc | 268 ++++++++++++++++++ test/auto_parallel/dist_custom_relu_op.cu | 166 +++++++++++ test/auto_parallel/dist_custom_relu_op_dup.cc | 38 +++ .../semi_auto_parallel_for_custom_relu.py | 123 ++++++++ 5 files changed, 615 insertions(+), 7 deletions(-) create mode 100644 test/auto_parallel/dist_custom_relu_op.cc create mode 100644 test/auto_parallel/dist_custom_relu_op.cu create mode 100644 test/auto_parallel/dist_custom_relu_op_dup.cc create mode 100644 test/auto_parallel/semi_auto_parallel_for_custom_relu.py diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index da8b9125a71ddd..bf6bb65c625adf 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -63,10 +64,12 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { "happens when handling inplace optional inputs & outputs."; return; } - PADDLE_ENFORCE_EQ(src.is_dense_tensor() && dst->is_dense_tensor(), - true, - phi::errors::Unavailable( - "Now only supported DenseTensor in Custom Operator.")); + PADDLE_ENFORCE_EQ( + ((src.is_dense_tensor() && dst->is_dense_tensor()) || + (src.is_dist_tensor() && dst->is_dist_tensor())), + true, + phi::errors::Unavailable( + "Now only supported DenseTensor and DistTensor in Custom Operator.")); PADDLE_ENFORCE_EQ( src.initialized(), true, @@ -76,9 +79,19 @@ PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { true, phi::errors::Unavailable( "The Custom OpKernel origin output is not defined.")); - auto& dense_src = static_cast(*src.impl()); - auto* dense_dst = static_cast(dst->impl().get()); - *dense_dst = dense_src; + if (src.is_dense_tensor()) { + auto& dense_src = static_cast(*src.impl()); + auto* dense_dst = static_cast(dst->impl().get()); + *dense_dst = dense_src; + } else { + auto* dense_src = + static_cast(src.impl().get()) + ->unsafe_mutable_value(); + auto* dense_dst = + static_cast(dst->impl().get()) + ->unsafe_mutable_value(); + *dense_dst = *dense_src; + } } ////////////////////// Kernel Context ////////////////////// diff --git a/test/auto_parallel/dist_custom_relu_op.cc b/test/auto_parallel/dist_custom_relu_op.cc new file mode 100644 index 00000000000000..5627bb28b921f4 --- /dev/null +++ b/test/auto_parallel/dist_custom_relu_op.cc @@ -0,0 +1,268 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "paddle/extension.h" + +#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") + +template +void relu_cpu_forward_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + PD_CHECK(x_data != nullptr, "x_data is nullptr."); + PD_CHECK(out_data != nullptr, "out_data is nullptr."); + for (int64_t i = 0; i < x_numel; ++i) { + out_data[i] = std::max(static_cast(0.), x_data[i]); + } +} + +template +void relu_cpu_backward_kernel(const data_t* grad_out_data, + const data_t* out_data, + data_t* grad_x_data, + int64_t out_numel) { + for (int64_t i = 0; i < out_numel; ++i) { + grad_x_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +template +void relu_cpu_double_backward_kernel(const data_t* out_data, + const data_t* ddx_data, + data_t* ddout_data, + int64_t ddout_numel) { + for (int64_t i = 0; i < ddout_numel; ++i) { + ddout_data[i] = + ddx_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +std::vector relu_cpu_forward(const paddle::Tensor& x) { + auto out = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out.data(), x.numel()); + })); + + return {out}; +} + +std::vector relu_cpu_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.data(), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cpu_double_backward( + const paddle::Tensor& out, const paddle::Tensor& ddx) { + CHECK_CPU_INPUT(out); + CHECK_CPU_INPUT(ddx); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] { + relu_cpu_double_backward_kernel( + out.data(), + ddx.data(), + ddout.mutable_data(out.place()), + ddout.size()); + })); + + return {ddout}; +} + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); +std::vector relu_cuda_double_backward( + const paddle::Tensor& out, const paddle::Tensor& ddx); + +std::vector ReluForward(const paddle::Tensor& x) { + if (x.is_cpu()) { + return relu_cpu_forward(x); + } else if (x.is_gpu()) { + return relu_cuda_forward(x); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + if (x.is_cpu()) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.is_gpu()) { + return relu_cuda_backward(x, out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluDoubleBackward(const paddle::Tensor& out, + const paddle::Tensor& ddx) { + if (out.is_cpu()) { + return relu_cpu_double_backward(out, ddx); + } else if (out.is_gpu()) { + return relu_cuda_double_backward(out, ddx); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector> ReluDoubleBackwardInferShape( + const std::vector& out_shape, + const std::vector& ddx_shape) { + return {out_shape}; +} + +PD_BUILD_OP(custom_relu) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); + 
+PD_BUILD_DOUBLE_GRAD_OP(custom_relu) + .Inputs({"Out", paddle::Grad(paddle::Grad("X"))}) + .Outputs({paddle::Grad(paddle::Grad("Out"))}) + .SetKernelFn(PD_KERNEL(ReluDoubleBackward)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape)); + +std::vector relu_cpu_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.mutable_data(out.place()), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out); + +std::vector ReluBackwardWithoutX( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + if (out.is_cpu()) { + return relu_cpu_backward_without_x(out, grad_out); + } else if (out.is_gpu()) { + return relu_cuda_backward_without_x(out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector> ReluBackwardWithoutXInferShape( + const std::vector& out_shape, + const std::vector& grad_out_shape) { + return {out_shape}; +} + +PD_BUILD_OP(custom_relu_no_x_in_backward) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward) + .Inputs({"Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape)); + +void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { + out->reshape(x.shape()); + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out->mutable_data(x.place()), x.numel()); + })); +} + +void relu_cpu_backward_out(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out, + paddle::Tensor* grad_x) { + grad_x->reshape(x.shape()); + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x->mutable_data(x.place()), + out.size()); + })); +} + +void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out); +void relu_cuda_backward_out(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out, + paddle::Tensor* grad_x); + +void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) { + if (x.is_cpu()) { + return relu_cpu_forward_out(x, out); + } else if (x.is_gpu()) { + return relu_cuda_forward_out(x, out); + } else { + PD_THROW("Not implemented."); + } +} + +void ReluBackwardOut(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out, + paddle::Tensor* grad_x) { + if (x.is_cpu()) { + return relu_cpu_backward_out(x, out, grad_out, grad_x); + } else if (x.is_gpu()) { + return relu_cuda_backward_out(x, out, grad_out, grad_x); + } else { + PD_THROW("Not implemented."); + } +} + +PD_BUILD_OP(custom_relu_out) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForwardOut)); + +PD_BUILD_GRAD_OP(custom_relu_out) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackwardOut)); diff --git a/test/auto_parallel/dist_custom_relu_op.cu b/test/auto_parallel/dist_custom_relu_op.cu new file mode 100644 index 00000000000000..49e5d16938eb80 --- /dev/null +++ b/test/auto_parallel/dist_custom_relu_op.cu @@ -0,0 +1,166 @@ +// Copyright (c) 2021 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = x[i] > static_cast(0.) ? x[i] : static_cast(0.); + } +} + +template +__global__ void relu_cuda_backward_kernel(const data_t* dy, + const data_t* y, + data_t* dx, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > static_cast(0.) ? static_cast(1.) + : static_cast(0.)); + } +} + +template +__global__ void relu_cuda_double_backward_kernel(const data_t* out_data, + const data_t* ddx_data, + data_t* ddout_data, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast(0.) + ? static_cast(1.) + : static_cast(0.)); + } +} + +std::vector relu_cuda_forward(const paddle::Tensor& x) { + CHECK_GPU_INPUT(x); + auto out = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.data(), numel); + })); + + return {out}; +} + +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + CHECK_GPU_INPUT(x); + CHECK_GPU_INPUT(out); + CHECK_GPU_INPUT(grad_out); + auto grad_x = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + numel); + })); + + return {grad_x}; +} + +std::vector relu_cuda_double_backward( + const paddle::Tensor& out, const paddle::Tensor& ddx) { + CHECK_GPU_INPUT(out); + CHECK_GPU_INPUT(ddx); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); + + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_double_backward_kernel", ([&] { + relu_cuda_double_backward_kernel + <<>>( + out.data(), + ddx.data(), + ddout.mutable_data(out.place()), + numel); + })); + + return {ddout}; +} + +std::vector relu_cuda_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); + + int numel = out.numel(); + int 
block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(out.place()), + numel); + })); + + return {grad_x}; +} + +void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { + int numel = x.numel(); + int block = 512; + int grid = (numel + block - 1) / block; + out->reshape(x.shape()); + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out->mutable_data(x.place()), numel); + })); +} + +void relu_cuda_backward_out(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out, + paddle::Tensor* grad_x) { + int numel = out.numel(); + int block = 512; + int grid = (numel + block - 1) / block; + grad_x->reshape(x.shape()); + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x->mutable_data(x.place()), + numel); + })); +} diff --git a/test/auto_parallel/dist_custom_relu_op_dup.cc b/test/auto_parallel/dist_custom_relu_op_dup.cc new file mode 100644 index 00000000000000..89d14bfa049603 --- /dev/null +++ b/test/auto_parallel/dist_custom_relu_op_dup.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x); + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +// Reuse codes in `custom_relu_op.cc/cu` to register another custom operator +// to test jointly compile multi operators at same time. +PD_BUILD_OP(custom_relu_dup) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu_dup) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py new file mode 100644 index 00000000000000..1242f98356e807 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + 'dist_custom_relu_op.cc', + 'dist_custom_relu_op_dup.cc', + 'dist_custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + + +class TestCustomReluForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body(self, x_shape, x_specs): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x_np = np.random.random(size=x_shape).astype(self._dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + out = custom_ops.custom_relu(x) + dist_out = custom_ops.custom_relu(dist_x) + out.stop_gradient = False + dist_out.stop_gradient = False + + self.check_tensor_eq(out, dist_out) + + out.backward() + dist_out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_custom_relu(self): + self.test_body( + x_shape=[64, 32], + x_specs=['x', None], + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise 
ValueError("Only support cpu or gpu backend.") + + self.test_custom_relu() + + +if __name__ == '__main__': + TestCustomReluForSemiAutoParallel().test_custom_relu() From ebc9369e74e0958aeedb9a28a1218452274da95d Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 31 Oct 2023 01:36:32 +0000 Subject: [PATCH 02/23] refine --- paddle/fluid/pybind/eager_functions.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index df3e62b3bae476..46499c695b99aa 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -67,6 +67,11 @@ typedef SSIZE_T ssize_t; #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + PHI_DECLARE_string(tensor_operants_mode); namespace paddle { From 429bfd8fe3b07014d0c2fa502471fb26980eb0dc Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 31 Oct 2023 11:24:40 +0000 Subject: [PATCH 03/23] refine --- paddle/fluid/pybind/eager_functions.cc | 574 ++++++++++++++++-- .../semi_auto_parallel_for_custom_relu.py | 2 + 2 files changed, 512 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 46499c695b99aa..1d9d6309341935 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -513,6 +513,512 @@ static Tensor InitializedEmptyTensor() { return tensor; } +static std::vector> RunDefaultInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; + result.emplace_back({ctx.InputAt(0).dims()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " + "the InferShapeFn of custom operator by " + "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", + outputs.size(), + inplace_map.size())); + for (auto const& pair : ctx.GetInplaceIndexMap()) { + if (detail::IsDuplicableVar(inputs[pair.first])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(pair.first); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(pair.first); + result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferShapeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector> input_shapes; + std::vector>> vec_input_shapes; + + VLOG(3) << "Custom Operator: InferShape - get input ddim."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second) { + input_shapes.emplace_back( + std::move(ctx.InputAt(input_pair.first).shape())); + } else { + std::vector> shapes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + shapes.push_back(std::move(ctx.InputAt(j).shape())); + } + vec_input_shapes.emplace_back(std::move(shapes)); + } + } + + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = func(input_shapes, vec_input_shapes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_shapes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferShapeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferShapeFn = %d. Please " + "check InferShapeFn again.", + outputs.size(), + output_shapes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_shapes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferShapeFn output size` = %d. Please check " + "InplaceMap and InferShapeFn again", + outputs.size(), + output_shapes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferShape - set output ddim: inplace_map.size() = " + << inplace_map.size() + << ", output_shapes.size() = " << output_shapes.size(); + size_t output_shape_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } else { + result.emplace_back( + {phi::make_ddim(output_shapes[output_shape_idx++])}); + } + } + } +} + +static std::vector> RunDefaultInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + platform::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + platform::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + + VLOG(3) << "Custom Operator: InferDtype - share dtype."; + result.emplace_back({ctx.InputAt(0).dtype()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " + "the InferDtypeFn of custom operator by " + "`.SetInferDtypeFn(PD_INFER_DTYPE(...))`", + outputs.size(), + inplace_map.size())); + for (auto const& pair : ctx.GetInplaceIndexMap()) { + if (detail::IsDuplicableVar(inputs[pair.first])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(pair.first); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(pair.first); + result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferDtypeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector input_dtypes; + std::vector> vec_input_dtypes; + + VLOG(3) << "Custom Operator: InferDtype - get input dtype."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second) { + input_dtypes.emplace_back( + std::move(ctx.InputAt(input_pair.first).dtype())); + } else { + std::vector> dtypes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + dtypes.push_back(std::move(ctx.InputAt(j).dtype())); + } + vec_input_dtypes.emplace_back(std::move(dtypes)); + } + } + + VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; + auto output_dtypes = func(input_dtypes, vec_input_dtypes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_dtypes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferDtypeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferDtypeFn = %d. Please " + "check InferDtypeFn again.", + outputs.size(), + output_dtypes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_dtypes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferDtypeFn output size` = %d. Please check " + "InplaceMap and InferDtypeFn again", + outputs.size(), + output_dtypes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferDtype - set output dtype: inplace_map.size() = " + << inplace_map.size() + << ", output_dtypes.size() = " << output_dtypes.size(); + size_t output_dtype_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector dtypes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + dtypes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(dtypes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } else { + result.emplace_back( + {phi::make_ddim(output_dtypes[output_dtype_idx++])}); + } + } + } +} + +phi::Tensor BuildEmptyDistPhiTensor(const ProcessMesh& process_mesh, + const phi::DDim& dims, + phi::DataType dtype) { + phi::Tensor empty_tensor; + DenseTensorMeta meta; + meta.dims = dims; + meta.dtype = dtype; + + auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(dims)); + dist_attr.set_process_mesh(process_mesh); + + auto dist_t = std::make_shared( + std::make_shared( + std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + meta), + dist_attr); + empty_tensor.set_impl(dist_t); + return empty_tensor; +} + +void run_custom_op_kernel( + const paddle::CustomOpKernelContext& ctx, + const std::vector& vec_map, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + bool run_auto_parallel = false; + bool rank_is_in_current_mesh = true; + ProcessMesh current_process_mesh; + std::vector* all_inputs = ctx.AllMutableInput(); + +#ifdef PADDLE_WITH_DISTRIBUTE + std::vector x = *all_inputs; + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, x)) { + ConvertAllInputsToDistTensor(mesh, x); + } + + run_auto_parallel = AllInputsAreDistTensor(x); + rank_is_in_current_mesh = true; + if (run_auto_parallel) { + for (size_t i = 0; i < all_inputs.size(); ++i) { + PADDLE_ENFORCE_EQ(all_inputs->at(i).initialized() && + all_inputs->at(i).is_dense_tensor() && + all_inputs->at(i).is_gpu(), + true, + phi::errors::InvalidArgument( + "The custom op's input tensor must be initialized " + "dense tensor on gpu, in AutoParallel mode.")); + } + + auto mesh = + std::static_pointer_cast(x.at(0).impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + std::vector input_x(x.size()); + for (size_t i = 0; i < input_x.size(); ++i) { + input_x[i] = x.at(i).impl().get(); + } + + auto meta_dist_input_x = MakeDistMetaTensor(input_x); + auto spmd_info = + phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); + current_process_mesh = spmd_info.first[0].process_mesh(); + + if (rank_is_in_current_mesh) { + auto* dev_ctx = static_cast(pool.Get(x.at(0).place())); + auto dist_input_x = + ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first); + for (size_t i = 0; i < x.size(); ++i) { + *static_cast(all_inputs->at(i).impl().get()) = + *(dist_input_x[i]->unsafe_mutable_value()); + } + } else { + auto& infer_shape_func = + paddle::OpMetaInfoHelper::GetInferShapeFn(vec_map[0]); + auto& infer_dtype_func = + paddle::OpMetaInfoHelper::GetInferDtypeFn(vec_map[0]); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = RunInferShapeFunc( + ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } + + std::vector> out_dtypes; + if (infer_dtype_func) { + 
out_dtypes = RunInferDtypeFunc( + ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + out_dtypes = + RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + out_dtypes.size(), + phi::errors::InvalidArgument("custome op infer_shape and infer_dtype " + "must have the same output size.")); + + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + PADDLE_ENFORCE_EQ( + out_dim.size(), + out_dtype.size(), + phi::errors::InvalidArgument( + "custome op infer_shape result[%d] and infer_dtype result[%d] " + "must have the same output size.", + i, + i)); + if (out_dim.size() == 0) { + ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + } else if (out_dim.size() == 1) { + ctx.EmplaceBackOutput(std::move(BuildEmptyDistPhiTensor( + current_process_mesh, out_dim[0], out_dtype[0]))); + } else { + std::vector out_tensors; + out_tensors.reverse(out_dim.size()); + for (size_t j = 0; j < out_dim.size(); ++j) { + out_tensors.emplace_back(BuildEmptyDistPhiTensor( + current_process_mesh, out_dim[j], out_dtype[j])); + } + ctx.EmplaceBackOutputs(out_tensors); + } + } + return; + } + } +#endif + + for (size_t i = 0; i < all_inputs->size(); ++i) { + auto& tensor = all_inputs->at(i); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl())))))); + } + } + + const auto& inplace_reverse_idx_map = ctx.GetInplaceReverseIndexMap(); + for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { + const auto& output = outputs.at(out_idx); + // inplace special case + if (inplace_reverse_idx_map.find(out_idx) != + inplace_reverse_idx_map.end()) { + size_t in_idx = inplace_reverse_idx_map.at(out_idx); + const auto& input_range = ctx.InputRangeAt(in_idx); + const auto& input_tensor = ctx.InputAt(input_range.first); + // inplace optional [Tensor or vector], un-initialized tensor. + if (paddle::framework::detail::IsOptionalVar(output) && + !input_tensor.initialized()) { + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add un-initialized tensor " + "because the inplace optional input is None"; + ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + continue; + } + /// inplace vector, initialized tensor. + if (paddle::framework::detail::IsDuplicableVar(output)) { + std::vector empty_tensors; + size_t vector_size = input_range.second - input_range.first; + empty_tensors.resize(vector_size); + for (size_t i = 0; i < vector_size; ++i) { + empty_tensors[i] = InitializedEmptyTensor(); + } + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add vector size = " + << empty_tensors.size(); + ctx.EmplaceBackOutputs(empty_tensors); + continue; + } + } + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add initialized Tensor because " + "using general or inplace mechanism"; + // general Tensor or inplace Tensor, initialized tensor. 
+ ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); + } + + // handle inplace map + ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); + VLOG(7) << "Run Kernel of Custom Op: " << op_type; + (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); + ctx.AssignInplaceOutputs(); + +#ifdef PADDLE_WITH_DISTRIBUTE + if (run_auto_parallel) { + std::vector* output_all = ctx.AllMutableOutput(); + for (size_t i = 0; i < output_all->size(); ++i) { + auto& tensor = output_all->at(i); + auto dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + tensor.impl(), dist_attr); + tensor.set_impl(dist_t); + } + } +#endif +} + static PyObject* eager_api_run_custom_op(PyObject* self, PyObject* args, PyObject* kwargs) { @@ -540,6 +1046,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, const auto& attrs = paddle::OpMetaInfoHelper::GetAttrs(vec_map[0]); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(vec_map[0]); const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[0]); + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); for (size_t i = 0; i < inputs.size(); ++i) { const auto& input = inputs.at(i); // Parse op_type first, so that use i + 1 @@ -557,17 +1064,6 @@ static PyObject* eager_api_run_custom_op(PyObject* self, if (paddle::framework::detail::IsDuplicableVar(input)) { std::vector tensors = std::move(CastPyArg2VectorOfTensor(obj, i + 1)); // NOLINT - for (auto& tensor : tensors) { - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast( - tensor.impl())))))); - } - } ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add vector size = " @@ -575,19 +1071,12 @@ static PyObject* eager_api_run_custom_op(PyObject* self, } else { paddle::Tensor tensor = std::move(CastPyArg2Tensor(obj, i + 1)); // NOLINT - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous(*( - std::dynamic_pointer_cast(tensor.impl())))))); - } ctx.EmplaceBackInput(std::move(tensor)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add Tensor for general case."; } } + // Parse op_type and inputs first, so that use 1 + inputs.size() + i int attr_start_idx = static_cast(1 + inputs.size()); for (size_t i = 0; i < attrs.size(); ++i) { @@ -633,54 +1122,11 @@ static PyObject* eager_api_run_custom_op(PyObject* self, attr_type_str)); } } + { eager_gil_scoped_release guard; - ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); - const auto& inplace_reverse_idx_map = ctx.GetInplaceReverseIndexMap(); - for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { - const auto& output = outputs.at(out_idx); - // inplace special case - if (inplace_reverse_idx_map.find(out_idx) != - inplace_reverse_idx_map.end()) { - size_t in_idx = inplace_reverse_idx_map.at(out_idx); - const auto& input_range = ctx.InputRangeAt(in_idx); - const auto& input_tensor = ctx.InputAt(input_range.first); - // inplace optional [Tensor or vector], un-initialized tensor. 
- if (paddle::framework::detail::IsOptionalVar(output) && - !input_tensor.initialized()) { - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add un-initialized tensor " - "because the inplace optional input is None"; - ctx.EmplaceBackOutput(std::move(paddle::Tensor())); - continue; - } - /// inplace vector, initialized tensor. - if (paddle::framework::detail::IsDuplicableVar(output)) { - std::vector empty_tensors; - size_t vector_size = input_range.second - input_range.first; - empty_tensors.resize(vector_size); - for (size_t i = 0; i < vector_size; ++i) { - empty_tensors[i] = InitializedEmptyTensor(); - } - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add vector size = " - << empty_tensors.size(); - ctx.EmplaceBackOutputs(empty_tensors); - continue; - } - } - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add initialized Tensor because " - "using general or inplace mechanism"; - // general Tensor or inplace Tensor, initialized tensor. - ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); - } - // handle inplace map - ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); - VLOG(7) << "Run Kernel of Custom Op: " << op_type; - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_kernel(); // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py index 1242f98356e807..2451071b09da86 100644 --- a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -91,6 +91,8 @@ def test_body(self, x_shape, x_specs): dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) dist_x.stop_gradient = False + x = paddle.add(x, x) + dist_x = paddle.add(dist_x, dist_x) out = custom_ops.custom_relu(x) dist_out = custom_ops.custom_relu(dist_x) out.stop_gradient = False From 08f25640ccc0ebd392983c0fc1823f95988f5bdc Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 04:37:57 +0000 Subject: [PATCH 04/23] refine --- paddle/fluid/eager/utils.cc | 33 +++++++ paddle/fluid/pybind/eager_functions.cc | 97 ++++++++++--------- paddle/phi/api/ext/op_meta_info.h | 14 +-- paddle/phi/api/lib/op_meta_info.cc | 16 +-- .../semi_auto_parallel_for_custom_relu.py | 8 +- 5 files changed, 101 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 5479f9263e477c..8f49f55af362c7 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -336,6 +336,10 @@ void EagerUtils::HandleViewBetweenInputAndOutput( std::dynamic_pointer_cast(input_tensor.impl()); if (view_output_tensor->impl() == nullptr) { view_output_tensor->set_impl(std::make_shared()); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dense_tensor(), + phi::errors::Unavailable( + "DenseTensor can not be inplaced with other Tensor.")); } auto view_output_dense_tensor = std::dynamic_pointer_cast(view_output_tensor->impl()); @@ -343,6 +347,35 @@ void EagerUtils::HandleViewBetweenInputAndOutput( view_output_dense_tensor->ShareInplaceVersionCounterWith( *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" + << view_output_tensor->name() << ") and Input Tensor(" + << input_tensor.name() + << "), share allocation and inplace version."; + } else if 
(input_tensor.is_dist_tensor()) { + auto input_dense_tensor = + std::dynamic_pointer_cast( + input_tensor.impl()) + ->unsafe_mutable_value(); + if (view_output_tensor->impl() == nullptr) { + view_output_tensor->set_impl( + std::make_shared( + input_tensor.dims(), + std::dynamic_pointer_cast( + input_tensor.impl()) + ->dist_attr())); + } else { + PADDLE_ENFORCE(view_output_tensor->is_dist_tensor(), + phi::errors::Unavailable( + "DistTensor can not be inplaced with other Tensor.")); + } + auto view_output_dense_tensor = + std::dynamic_pointer_cast( + view_output_tensor->impl()) + ->unsafe_mutable_value(); + view_output_dense_tensor->ShareBufferWith(*input_dense_tensor); + view_output_dense_tensor->ShareInplaceVersionCounterWith( + *input_dense_tensor); + VLOG(4) << "Perform View between Output Tensor(" << view_output_tensor->name() << ") and Input Tensor(" << input_tensor.name() diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 1d9d6309341935..2051839cfc7667 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -68,6 +68,7 @@ typedef SSIZE_T ssize_t; #include "paddle/phi/core/flags.h" #ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" #endif @@ -544,7 +545,7 @@ static std::vector> RunDefaultInferShapeFunc( "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); VLOG(3) << "Custom Operator: Default InferShape - share ddim."; - result.emplace_back({ctx.InputAt(0).dims()}); + result.push_back({ctx.InputAt(0).dims()}); } else { // inplace case PADDLE_ENFORCE_EQ( inplace_map.size(), @@ -557,8 +558,9 @@ static std::vector> RunDefaultInferShapeFunc( "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", outputs.size(), inplace_map.size())); - for (auto const& pair : ctx.GetInplaceIndexMap()) { - if (detail::IsDuplicableVar(inputs[pair.first])) { + auto inplace_index_map = ctx.GetInplaceIndexMap(); + for (auto const& pair : inplace_index_map) { + if (paddle::framework::detail::IsDuplicableVar(inputs[pair.first])) { std::vector shapes; auto duplicable_input_pair = ctx.InputRangeAt(pair.first); for (size_t j = duplicable_input_pair.first; @@ -569,7 +571,7 @@ static std::vector> RunDefaultInferShapeFunc( result.emplace_back(std::move(shapes)); } else { auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); } } } @@ -633,7 +635,7 @@ static std::vector> RunInferShapeFunc( size_t output_shape_idx = 0; auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); for (size_t i = 0; i < outputs.size(); ++i) { - if (detail::IsDuplicableVar(outputs[i])) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { PADDLE_ENFORCE( inplace_reverse_map.find(i) != inplace_reverse_map.end(), phi::errors::InvalidArgument( @@ -652,13 +654,13 @@ static std::vector> RunInferShapeFunc( } else { if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); } else { - result.emplace_back( - {phi::make_ddim(output_shapes[output_shape_idx++])}); + result.push_back({phi::make_ddim(output_shapes[output_shape_idx++])}); } } } + return 
result; } static std::vector> RunDefaultInferDtypeFunc( @@ -692,7 +694,7 @@ static std::vector> RunDefaultInferDtypeFunc( "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); VLOG(3) << "Custom Operator: InferDtype - share dtype."; - result.emplace_back({ctx.InputAt(0).dtype()}); + result.push_back({ctx.InputAt(0).dtype()}); } else { // inplace case PADDLE_ENFORCE_EQ( inplace_map.size(), @@ -706,7 +708,7 @@ static std::vector> RunDefaultInferDtypeFunc( outputs.size(), inplace_map.size())); for (auto const& pair : ctx.GetInplaceIndexMap()) { - if (detail::IsDuplicableVar(inputs[pair.first])) { + if (paddle::framework::detail::IsDuplicableVar(inputs[pair.first])) { std::vector shapes; auto duplicable_input_pair = ctx.InputRangeAt(pair.first); for (size_t j = duplicable_input_pair.first; @@ -717,7 +719,7 @@ static std::vector> RunDefaultInferDtypeFunc( result.emplace_back(std::move(shapes)); } else { auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); } } } @@ -741,9 +743,9 @@ static std::vector> RunInferDtypeFunc( input_dtypes.emplace_back( std::move(ctx.InputAt(input_pair.first).dtype())); } else { - std::vector> dtypes; + std::vector dtypes; for (size_t j = input_pair.first; j < input_pair.second; j++) { - dtypes.push_back(std::move(ctx.InputAt(j).dtype())); + dtypes.emplace_back(ctx.InputAt(j).dtype()); } vec_input_dtypes.emplace_back(std::move(dtypes)); } @@ -781,7 +783,7 @@ static std::vector> RunInferDtypeFunc( size_t output_dtype_idx = 0; auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); for (size_t i = 0; i < outputs.size(); ++i) { - if (detail::IsDuplicableVar(outputs[i])) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { PADDLE_ENFORCE( inplace_reverse_map.find(i) != inplace_reverse_map.end(), phi::errors::InvalidArgument( @@ -800,20 +802,21 @@ static std::vector> RunInferDtypeFunc( } else { if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - result.emplace_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); } else { - result.emplace_back( - {phi::make_ddim(output_dtypes[output_dtype_idx++])}); + result.push_back({output_dtypes[output_dtype_idx++]}); } } } + return result; } -phi::Tensor BuildEmptyDistPhiTensor(const ProcessMesh& process_mesh, - const phi::DDim& dims, - phi::DataType dtype) { - phi::Tensor empty_tensor; - DenseTensorMeta meta; +paddle::Tensor BuildEmptyDistPaddleTensor( + const phi::distributed::ProcessMesh& process_mesh, + const phi::DDim& dims, + phi::DataType dtype) { + paddle::Tensor empty_tensor; + phi::DenseTensorMeta meta; meta.dims = dims; meta.dtype = dtype; @@ -831,14 +834,14 @@ phi::Tensor BuildEmptyDistPhiTensor(const ProcessMesh& process_mesh, } void run_custom_op_kernel( - const paddle::CustomOpKernelContext& ctx, + paddle::CustomOpKernelContext& ctx, // NOLINT const std::vector& vec_map, const std::vector& inputs, const std::vector& outputs, const std::unordered_map& inplace_map) { bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; - ProcessMesh current_process_mesh; + phi::distributed::ProcessMesh current_process_mesh; std::vector* all_inputs = ctx.AllMutableInput(); #ifdef PADDLE_WITH_DISTRIBUTE @@ -848,17 +851,16 @@ void run_custom_op_kernel( ConvertAllInputsToDistTensor(mesh, x); } - 
run_auto_parallel = AllInputsAreDistTensor(x); + run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); rank_is_in_current_mesh = true; if (run_auto_parallel) { - for (size_t i = 0; i < all_inputs.size(); ++i) { - PADDLE_ENFORCE_EQ(all_inputs->at(i).initialized() && - all_inputs->at(i).is_dense_tensor() && - all_inputs->at(i).is_gpu(), - true, - phi::errors::InvalidArgument( - "The custom op's input tensor must be initialized " - "dense tensor on gpu, in AutoParallel mode.")); + for (size_t i = 0; i < all_inputs->size(); ++i) { + PADDLE_ENFORCE_EQ( + all_inputs->at(i).initialized() && all_inputs->at(i).is_gpu(), + true, + phi::errors::InvalidArgument( + "The custom op's input tensor must be initialized " + "tensor on gpu, in AutoParallel mode.")); } auto mesh = @@ -872,18 +874,20 @@ void run_custom_op_kernel( input_x[i] = x.at(i).impl().get(); } - auto meta_dist_input_x = MakeDistMetaTensor(input_x); + auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); auto spmd_info = phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); current_process_mesh = spmd_info.first[0].process_mesh(); if (rank_is_in_current_mesh) { - auto* dev_ctx = static_cast(pool.Get(x.at(0).place())); + auto* dev_ctx = static_cast( + phi::DeviceContextPool::Instance().Get(x.at(0).place())); auto dist_input_x = - ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first); + paddle::experimental::ReshardApiInputToReplicatedKernelInput( + dev_ctx, x, spmd_info.first); for (size_t i = 0; i < x.size(); ++i) { - *static_cast(all_inputs->at(i).impl().get()) = - *(dist_input_x[i]->unsafe_mutable_value()); + all_inputs->at(i).set_impl(std::make_shared( + *(dist_input_x[i]->unsafe_mutable_value()))); } } else { auto& infer_shape_func = @@ -928,13 +932,13 @@ void run_custom_op_kernel( if (out_dim.size() == 0) { ctx.EmplaceBackOutput(std::move(paddle::Tensor())); } else if (out_dim.size() == 1) { - ctx.EmplaceBackOutput(std::move(BuildEmptyDistPhiTensor( + ctx.EmplaceBackOutput(std::move(BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]))); } else { std::vector out_tensors; - out_tensors.reverse(out_dim.size()); + out_tensors.reserve(out_dim.size()); for (size_t j = 0; j < out_dim.size(); ++j) { - out_tensors.emplace_back(BuildEmptyDistPhiTensor( + out_tensors.emplace_back(BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[j], out_dtype[j])); } ctx.EmplaceBackOutputs(out_tensors); @@ -999,7 +1003,7 @@ void run_custom_op_kernel( // handle inplace map ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); - VLOG(7) << "Run Kernel of Custom Op: " << op_type; + VLOG(7) << "Begin run Kernel of Custom Op"; (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); ctx.AssignInplaceOutputs(); @@ -1008,11 +1012,12 @@ void run_custom_op_kernel( std::vector* output_all = ctx.AllMutableOutput(); for (size_t i = 0; i < output_all->size(); ++i) { auto& tensor = output_all->at(i); - auto dist_attr = + phi::distributed::TensorDistAttr dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); dist_attr.set_process_mesh(current_process_mesh); auto dist_t = std::make_shared( - tensor.impl(), dist_attr); + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); tensor.set_impl(dist_t); } } @@ -1125,8 +1130,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, { eager_gil_scoped_release guard; - - run_custom_op_kernel(); + VLOG(7) << "Run Kernel of Custom Op: " << op_type; + run_custom_op_kernel(ctx, vec_map, inputs, outputs, inplace_map); 
// handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index c774cafcfd26a8..484ea069446537 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -120,15 +120,15 @@ class PADDLE_API CustomOpKernelContext { std::vector InputsBetween(size_t start, size_t end) const; Tensor& MutableInputAt(size_t idx); std::vector* AllMutableInput(); - paddle::optional OptionalInputAt(size_t idx); + paddle::optional OptionalInputAt(size_t idx) const; paddle::optional> OptionalInputsBetween(size_t start, - size_t end); + size_t end) const; const std::vector& Attrs() const; - const std::vector>& InputRange(); - const std::vector>& OutputRange(); + const std::vector>& InputRange() const; + const std::vector>& OutputRange() const; Tensor* MutableOutputAt(size_t idx); std::vector MutableOutputBetween(size_t start, size_t end); - std::vector OutputsBetween(size_t start, size_t end); + std::vector OutputsBetween(size_t start, size_t end) const; std::vector* AllMutableOutput(); template @@ -151,8 +151,8 @@ class PADDLE_API CustomOpKernelContext { const std::unordered_map& inplace_map); void AssignInplaceOutputs(); std::vector* AllMutablePlainOutput(); - std::unordered_map GetInplaceIndexMap(); - std::unordered_map GetInplaceReverseIndexMap(); + std::unordered_map GetInplaceIndexMap() const; + std::unordered_map GetInplaceReverseIndexMap() const; private: // TODO(chenweihang): replaced be SmallVector diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index bf6bb65c625adf..14334aa7c42a6d 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -162,7 +162,8 @@ std::vector* CustomOpKernelContext::AllMutableInput() { return &inputs_; } -paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { +paddle::optional CustomOpKernelContext::OptionalInputAt( + size_t idx) const { if (!inputs_.at(idx).is_initialized()) { return paddle::none; } @@ -170,7 +171,7 @@ paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { } paddle::optional> -CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) { +CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { if (!inputs_.at(i).is_initialized()) { @@ -194,7 +195,7 @@ std::vector CustomOpKernelContext::MutableOutputBetween(size_t start, } std::vector CustomOpKernelContext::OutputsBetween(size_t start, - size_t end) { + size_t end) const { std::vector rlt; for (size_t i = start; i < end; ++i) { rlt.emplace_back(outputs_.at(i)); @@ -216,12 +217,12 @@ const std::pair& CustomOpKernelContext::OutputRangeAt( } const std::vector>& -CustomOpKernelContext::InputRange() { +CustomOpKernelContext::InputRange() const { return input_range_; } const std::vector>& -CustomOpKernelContext::OutputRange() { +CustomOpKernelContext::OutputRange() const { return output_range_; } @@ -306,12 +307,13 @@ std::vector* CustomOpKernelContext::AllMutablePlainOutput() { return &plain_outputs_; } -std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() { +std::unordered_map CustomOpKernelContext::GetInplaceIndexMap() + const { return inplace_idx_map_; } std::unordered_map -CustomOpKernelContext::GetInplaceReverseIndexMap() { +CustomOpKernelContext::GetInplaceReverseIndexMap() const { return inplace_reverse_idx_map_; } ////////////////////// Op Meta 
Info ////////////////////// diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py index 2451071b09da86..b7150292403e39 100644 --- a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -111,13 +111,7 @@ def test_custom_relu(self): ) def run_test_case(self): - if self._backend == "cpu": - paddle.set_device("cpu") - elif self._backend == "gpu": - paddle.set_device("gpu:" + str(dist.get_rank())) - else: - raise ValueError("Only support cpu or gpu backend.") - + paddle.set_device("gpu:" + str(dist.get_rank())) self.test_custom_relu() From ae6d1b14766ecd94080043d44ae1970230501430 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 06:25:51 +0000 Subject: [PATCH 05/23] refine --- paddle/fluid/pybind/eager_functions.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 2051839cfc7667..4abed06a2ce307 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -1020,6 +1020,17 @@ void run_custom_op_kernel( dist_attr); tensor.set_impl(dist_t); } + std::vector* input_all = ctx.AllMutableInput(); + for (size_t i = 0; i < input_all->size(); ++i) { + auto& tensor = input_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } } #endif } From d28838e2dc5fb5242d04b87f6395b3fea2051737 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 10:37:16 +0000 Subject: [PATCH 06/23] refine --- .../eager/custom_operator/CMakeLists.txt | 5 + .../custom_operator/custom_operator_node.cc | 27 +- .../custom_operator_run_kernel_impl.cc | 517 ++++++++++++++++ .../custom_operator_run_kernel_impl.h | 22 + paddle/fluid/pybind/eager_functions.cc | 568 ++---------------- .../semi_auto_parallel_for_custom_relu.py | 8 +- 6 files changed, 605 insertions(+), 542 deletions(-) create mode 100644 paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc create mode 100644 paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h diff --git a/paddle/fluid/eager/custom_operator/CMakeLists.txt b/paddle/fluid/eager/custom_operator/CMakeLists.txt index a2648d3e325568..88de909ec4965c 100644 --- a/paddle/fluid/eager/custom_operator/CMakeLists.txt +++ b/paddle/fluid/eager/custom_operator/CMakeLists.txt @@ -1,4 +1,9 @@ cc_library( custom_operator_node SRCS custom_operator_node.cc + DEPS phi grad_node_info custom_operator utils custom_operator_run_kernel_impl) + +cc_library( + custom_operator_run_kernel_impl + SRCS custom_operator_run_kernel_impl.cc DEPS phi grad_node_info custom_operator utils) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 5643c0e69391f0..089a0321f175d4 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/eager/custom_operator/custom_operator_node.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include 
"paddle/fluid/platform/profiler/event_tracing.h" @@ -172,8 +173,6 @@ RunCustomOpNode::operator()(paddle::small_vector, paddle::OpMetaInfoHelper::GetInputs(vec_map[1]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[1]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[1]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -251,11 +250,12 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[1]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[1], ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { @@ -386,8 +386,6 @@ RunCustomOpDoubleGradNode::operator()( paddle::OpMetaInfoHelper::GetInputs(vec_map[2]); const auto& grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[2]); - const auto& grad_inplace_map = - paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[2]); const auto& map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); @@ -451,11 +449,12 @@ RunCustomOpDoubleGradNode::operator()( } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad_grad"; - // handle inplace map - ctx.UpdatePlainOutputs( - grad_inputs_name, grad_outputs_names, grad_inplace_map); - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[2]))(&ctx); - ctx.AssignInplaceOutputs(); + run_custom_op_impl(vec_map[2], ctx); + + for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { + auto output_pair = ctx.OutputRangeAt(i); + outs[i] = ctx.OutputsBetween(output_pair.first, output_pair.second); + } return outs; } diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc new file mode 100644 index 00000000000000..d022bac23abddd --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -0,0 +1,517 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" + +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/custom_operator_utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + +namespace egr { + +using Tensor = paddle::Tensor; + +static std::vector> RunDefaultInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; + result.push_back({ctx.InputAt(0).dims()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " + "the InferShapeFn of custom operator by " + "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } + } + } + return result; +} + +static std::vector> RunInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferShapeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector> input_shapes; + std::vector>> vec_input_shapes; + + VLOG(3) << "Custom Operator: InferShape - get input ddim."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second) { + input_shapes.emplace_back( + std::move(ctx.InputAt(input_pair.first).shape())); + } else { + std::vector> shapes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + shapes.push_back(std::move(ctx.InputAt(j).shape())); + } + vec_input_shapes.emplace_back(std::move(shapes)); + } + } + + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = func(input_shapes, vec_input_shapes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_shapes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferShapeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferShapeFn = %d. Please " + "check InferShapeFn again.", + outputs.size(), + output_shapes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_shapes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferShapeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferShapeFn output size` = %d. Please check " + "InplaceMap and InferShapeFn again", + outputs.size(), + output_shapes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferShape - set output ddim: inplace_map.size() = " + << inplace_map.size() + << ", output_shapes.size() = " << output_shapes.size(); + size_t output_shape_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dims()); + } + result.emplace_back(std::move(shapes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); + } else { + result.push_back({phi::make_ddim(output_shapes[output_shape_idx++])}); + } + } + } + return result; +} + +static std::vector> RunDefaultInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + if (inplace_map.empty()) { // general case, assure single input and output + PADDLE_ENFORCE_EQ( + inputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + + VLOG(3) << "Custom Operator: InferDtype - share dtype."; + result.push_back({ctx.InputAt(0).dtype()}); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + outputs.size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " + "the InferDtypeFn of custom operator by " + "`.SetInferDtypeFn(PD_INFER_DTYPE(...))`", + outputs.size(), + inplace_map.size())); + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(inputs[i])) { + std::vector shapes; + auto duplicable_input_pair = ctx.InputRangeAt(i); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + shapes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(shapes)); + } else { + auto duplicable_input_pair = ctx.InputRangeAt(i); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } + } + } + return result; +} + +static std::vector> RunInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const paddle::InferDtypeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + std::vector> result; + std::vector input_dtypes; + std::vector> vec_input_dtypes; + + VLOG(3) << "Custom Operator: InferDtype - get input dtype."; + for (size_t i = 0; i < ctx.InputRange().size(); ++i) { + const auto& input_pair = ctx.InputRangeAt(i); + if (input_pair.first == input_pair.second) { + input_dtypes.emplace_back( + std::move(ctx.InputAt(input_pair.first).dtype())); + } else { + std::vector dtypes; + for (size_t j = input_pair.first; j < input_pair.second; j++) { + dtypes.emplace_back(ctx.InputAt(j).dtype()); + } + vec_input_dtypes.emplace_back(std::move(dtypes)); + } + } + + VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; + auto output_dtypes = func(input_dtypes, vec_input_dtypes, ctx.Attrs()); + if (inplace_map.empty()) { + PADDLE_ENFORCE_EQ(outputs.size(), + output_dtypes.size(), + phi::errors::InvalidArgument( + "Your custom operator has set the InferDtypeFn. " + "However, `Outputs` size = %d does not match the " + "returned vector size of InferDtypeFn = %d. Please " + "check InferDtypeFn again.", + outputs.size(), + output_dtypes.size())); + } else { + PADDLE_ENFORCE_EQ( + outputs.size(), + output_dtypes.size() + inplace_map.size(), + phi::errors::InvalidArgument( + "Your custom operator uses `SetInplaceMap` and sets the " + "InferDtypeFn. However, `Outputs` size = %d does not match the " + "`InplaceMap size + InferDtypeFn output size` = %d. Please check " + "InplaceMap and InferDtypeFn again", + outputs.size(), + output_dtypes.size() + inplace_map.size())); + } + + VLOG(3) + << "Custom Operator: InferDtype - set output dtype: inplace_map.size() = " + << inplace_map.size() + << ", output_dtypes.size() = " << output_dtypes.size(); + size_t output_dtype_idx = 0; + auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { + PADDLE_ENFORCE( + inplace_reverse_map.find(i) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manully.")); + std::vector dtypes; + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + for (size_t j = duplicable_input_pair.first; + j < duplicable_input_pair.second; + j++) { + dtypes.push_back(ctx.InputAt(j).dtype()); + } + result.emplace_back(std::move(dtypes)); + } else { + if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { + auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); + result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); + } else { + result.push_back({output_dtypes[output_dtype_idx++]}); + } + } + } + return result; +} + +paddle::Tensor BuildEmptyDistPaddleTensor( + const phi::distributed::ProcessMesh& process_mesh, + const phi::DDim& dims, + phi::DataType dtype) { + paddle::Tensor empty_tensor; + phi::DenseTensorMeta meta; + meta.dims = dims; + meta.dtype = dtype; + + auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(dims)); + dist_attr.set_process_mesh(process_mesh); + + auto dist_t = std::make_shared( + std::make_shared( + std::make_shared( + nullptr, 0, phi::distributed::GetDefaultPlace()), + meta), + dist_attr); + empty_tensor.set_impl(dist_t); + return empty_tensor; +} + +void run_custom_op_impl(paddle::OpMetaInfo op_info, + paddle::CustomOpKernelContext& ctx) { // NOLINT + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + + bool run_auto_parallel = false; + bool rank_is_in_current_mesh = true; + phi::distributed::ProcessMesh current_process_mesh; + std::vector* all_inputs = ctx.AllMutableInput(); + +#ifdef PADDLE_WITH_DISTRIBUTE + std::vector x = *all_inputs; + const phi::distributed::ProcessMesh* mesh = nullptr; + if (paddle::pybind::InputsContainDistTensor(&mesh, x)) { + paddle::pybind::ConvertAllInputsToDistTensor(mesh, x); + } + + run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); + rank_is_in_current_mesh = true; + if (run_auto_parallel) { + for (size_t i = 0; i < all_inputs->size(); ++i) { + PADDLE_ENFORCE_EQ( + all_inputs->at(i).initialized() && all_inputs->at(i).is_gpu(), + true, + phi::errors::InvalidArgument( + "The custom op's input tensor must be initialized " + "tensor on gpu, in AutoParallel mode.")); + } + + auto mesh = + std::static_pointer_cast(x.at(0).impl()) + ->dist_attr() + .process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + std::vector input_x(x.size()); + for (size_t i = 0; i < input_x.size(); ++i) { + input_x[i] = x.at(i).impl().get(); + } + + auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); + auto spmd_info = + phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); + current_process_mesh = spmd_info.first[0].process_mesh(); + + if (rank_is_in_current_mesh) { + auto* dev_ctx = static_cast( + phi::DeviceContextPool::Instance().Get(x.at(0).place())); + auto dist_input_x = + paddle::experimental::ReshardApiInputToReplicatedKernelInput( + dev_ctx, x, spmd_info.first); + for (size_t i = 0; i < x.size(); ++i) { + all_inputs->at(i).set_impl(std::make_shared( + *(dist_input_x[i]->unsafe_mutable_value()))); + } + } else { + auto& infer_shape_func = + paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + auto& infer_dtype_func = + 
paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = RunInferShapeFunc( + ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } + + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = RunInferDtypeFunc( + ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + out_dtypes = + RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + out_dtypes.size(), + phi::errors::InvalidArgument("custome op infer_shape and infer_dtype " + "must have the same output size.")); + + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + PADDLE_ENFORCE_EQ( + out_dim.size(), + out_dtype.size(), + phi::errors::InvalidArgument( + "custome op infer_shape result[%d] and infer_dtype result[%d] " + "must have the same output size.", + i, + i)); + if (out_dim.size() == 0) { + ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + } else if (out_dim.size() == 1) { + ctx.EmplaceBackOutput(std::move(BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[0], out_dtype[0]))); + } else { + std::vector out_tensors; + out_tensors.reserve(out_dim.size()); + for (size_t j = 0; j < out_dim.size(); ++j) { + out_tensors.emplace_back(BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[j], out_dtype[j])); + } + ctx.EmplaceBackOutputs(out_tensors); + } + } + return; + } + } +#endif + + for (size_t i = 0; i < all_inputs->size(); ++i) { + auto& tensor = all_inputs->at(i); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl())))))); + } + } + + // handle inplace map + ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); + VLOG(7) << "Begin run Kernel of Custom Op"; + (*paddle::OpMetaInfoHelper::GetKernelFn(op_info))(&ctx); + ctx.AssignInplaceOutputs(); + +#ifdef PADDLE_WITH_DISTRIBUTE + if (run_auto_parallel) { + std::vector* output_all = ctx.AllMutableOutput(); + for (size_t i = 0; i < output_all->size(); ++i) { + auto& tensor = output_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + std::vector* input_all = ctx.AllMutableInput(); + for (size_t i = 0; i < input_all->size(); ++i) { + auto& tensor = input_all->at(i); + phi::distributed::TensorDistAttr dist_attr = + phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); + dist_attr.set_process_mesh(current_process_mesh); + auto dist_t = std::make_shared( + std::dynamic_pointer_cast(tensor.impl()), + dist_attr); + tensor.set_impl(dist_t); + } + } +#endif +} + +} // namespace egr diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h new file mode 100644 index 00000000000000..4dfac34141aec1 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/api/ext/op_meta_info.h" + +namespace egr { +void run_custom_op_impl(paddle::OpMetaInfo op_info, + paddle::CustomOpKernelContext& ctx); // NOLINT +} // namespace egr diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 49d5135a9b95a0..e170c335b30051 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -62,11 +62,11 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/pybind/cuda_streams_py.h" #endif +#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/flags.h" - #ifdef PADDLE_WITH_DISTRIBUTE #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" @@ -514,527 +514,6 @@ static Tensor InitializedEmptyTensor() { return tensor; } -static std::vector> RunDefaultInferShapeFunc( - const paddle::CustomOpKernelContext& ctx, - const std::vector& inputs, - const std::vector& outputs, - const std::unordered_map& inplace_map) { - std::vector> result; - if (inplace_map.empty()) { // general case, assure single input and output - PADDLE_ENFORCE_EQ( - inputs.size(), - 1UL, - phi::errors::Unavailable( - "Your custom operator contains multiple inputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferShapeFn. " - "At this time, the input shape will be directly set to " - "the output shape.\n" - "Please set the InferShapeFn of custom " - "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); - PADDLE_ENFORCE_EQ( - outputs.size(), - 1UL, - phi::errors::Unavailable( - "Your custom operator contains multiple outputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferShapeFn. " - "At this time, the input shape will be directly set to " - "the output shape.\n" - "Please set the InferShapeFn of custom " - "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); - - VLOG(3) << "Custom Operator: Default InferShape - share ddim."; - result.push_back({ctx.InputAt(0).dims()}); - } else { // inplace case - PADDLE_ENFORCE_EQ( - inplace_map.size(), - outputs.size(), - phi::errors::Unavailable( - "Your custom operator uses `SetInplaceMap` without setting the " - "InferShapeFn. However, `Outputs` size = %d does not match the " - "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " - "the InferShapeFn of custom operator by " - "`.SetInferShapeFn(PD_INFER_SHAPE(...)`)", - outputs.size(), - inplace_map.size())); - auto inplace_index_map = ctx.GetInplaceIndexMap(); - for (auto const& pair : inplace_index_map) { - if (paddle::framework::detail::IsDuplicableVar(inputs[pair.first])) { - std::vector shapes; - auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - for (size_t j = duplicable_input_pair.first; - j < duplicable_input_pair.second; - j++) { - shapes.push_back(ctx.InputAt(j).dims()); - } - result.emplace_back(std::move(shapes)); - } else { - auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); - } - } - } - return result; -} - -static std::vector> RunInferShapeFunc( - const paddle::CustomOpKernelContext& ctx, - const paddle::InferShapeFunc& func, - const std::vector& inputs, - const std::vector& outputs, - const std::unordered_map& inplace_map) { - std::vector> result; - std::vector> input_shapes; - std::vector>> vec_input_shapes; - - VLOG(3) << "Custom Operator: InferShape - get input ddim."; - for (size_t i = 0; i < ctx.InputRange().size(); ++i) { - const auto& input_pair = ctx.InputRangeAt(i); - if (input_pair.first == input_pair.second) { - input_shapes.emplace_back( - std::move(ctx.InputAt(input_pair.first).shape())); - } else { - std::vector> shapes; - for (size_t j = input_pair.first; j < input_pair.second; j++) { - shapes.push_back(std::move(ctx.InputAt(j).shape())); - } - vec_input_shapes.emplace_back(std::move(shapes)); - } - } - - VLOG(3) << "Custom Operator: InferShape - calc output ddim."; - auto output_shapes = func(input_shapes, vec_input_shapes, ctx.Attrs()); - if (inplace_map.empty()) { - PADDLE_ENFORCE_EQ(outputs.size(), - output_shapes.size(), - phi::errors::InvalidArgument( - "Your custom operator has set the InferShapeFn. " - "However, `Outputs` size = %d does not match the " - "returned vector size of InferShapeFn = %d. Please " - "check InferShapeFn again.", - outputs.size(), - output_shapes.size())); - } else { - PADDLE_ENFORCE_EQ( - outputs.size(), - output_shapes.size() + inplace_map.size(), - phi::errors::InvalidArgument( - "Your custom operator uses `SetInplaceMap` and sets the " - "InferShapeFn. However, `Outputs` size = %d does not match the " - "`InplaceMap size + InferShapeFn output size` = %d. Please check " - "InplaceMap and InferShapeFn again", - outputs.size(), - output_shapes.size() + inplace_map.size())); - } - - VLOG(3) - << "Custom Operator: InferShape - set output ddim: inplace_map.size() = " - << inplace_map.size() - << ", output_shapes.size() = " << output_shapes.size(); - size_t output_shape_idx = 0; - auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); - for (size_t i = 0; i < outputs.size(); ++i) { - if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { - PADDLE_ENFORCE( - inplace_reverse_map.find(i) != inplace_reverse_map.end(), - phi::errors::InvalidArgument( - "Custom operator only supports `paddle::Vec(...)` inputs and " - "cannot support `paddle::Vec(...)` output without setting " - "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " - "please indicate it by setting InplaceMap manully.")); - std::vector shapes; - auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - for (size_t j = duplicable_input_pair.first; - j < duplicable_input_pair.second; - j++) { - shapes.push_back(ctx.InputAt(j).dims()); - } - result.emplace_back(std::move(shapes)); - } else { - if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { - auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - result.push_back({ctx.InputAt(duplicable_input_pair.first).dims()}); - } else { - result.push_back({phi::make_ddim(output_shapes[output_shape_idx++])}); - } - } - } - return result; -} - -static std::vector> RunDefaultInferDtypeFunc( - const paddle::CustomOpKernelContext& ctx, - const std::vector& inputs, - const std::vector& outputs, - const std::unordered_map& inplace_map) { - std::vector> result; - if (inplace_map.empty()) { // general case, assure single input and output - PADDLE_ENFORCE_EQ( - inputs.size(), - 1UL, - platform::errors::Unavailable( - "Your custom operator contains multiple inputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferDtypeFn. " - "At this time, the input dtype will be directly set to " - "the output dtype.\n" - "Please set the InferDtypeFn of custom " - "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); - PADDLE_ENFORCE_EQ( - outputs.size(), - 1UL, - platform::errors::Unavailable( - "Your custom operator contains multiple outputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferDtypeFn. " - "At this time, the input dtype will be directly set to " - "the output dtype.\n" - "Please set the InferDtypeFn of custom " - "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); - - VLOG(3) << "Custom Operator: InferDtype - share dtype."; - result.push_back({ctx.InputAt(0).dtype()}); - } else { // inplace case - PADDLE_ENFORCE_EQ( - inplace_map.size(), - outputs.size(), - phi::errors::Unavailable( - "Your custom operator uses `SetInplaceMap` without setting the " - "InferDtypeFn. However, `Outputs` size = %d does not match the " - "`InplaceMap` size = %d. 
Please check `SetInplaceMap` again or set " - "the InferDtypeFn of custom operator by " - "`.SetInferDtypeFn(PD_INFER_DTYPE(...))`", - outputs.size(), - inplace_map.size())); - for (auto const& pair : ctx.GetInplaceIndexMap()) { - if (paddle::framework::detail::IsDuplicableVar(inputs[pair.first])) { - std::vector shapes; - auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - for (size_t j = duplicable_input_pair.first; - j < duplicable_input_pair.second; - j++) { - shapes.push_back(ctx.InputAt(j).dtype()); - } - result.emplace_back(std::move(shapes)); - } else { - auto duplicable_input_pair = ctx.InputRangeAt(pair.first); - result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); - } - } - } - return result; -} - -static std::vector> RunInferDtypeFunc( - const paddle::CustomOpKernelContext& ctx, - const paddle::InferDtypeFunc& func, - const std::vector& inputs, - const std::vector& outputs, - const std::unordered_map& inplace_map) { - std::vector> result; - std::vector input_dtypes; - std::vector> vec_input_dtypes; - - VLOG(3) << "Custom Operator: InferDtype - get input dtype."; - for (size_t i = 0; i < ctx.InputRange().size(); ++i) { - const auto& input_pair = ctx.InputRangeAt(i); - if (input_pair.first == input_pair.second) { - input_dtypes.emplace_back( - std::move(ctx.InputAt(input_pair.first).dtype())); - } else { - std::vector dtypes; - for (size_t j = input_pair.first; j < input_pair.second; j++) { - dtypes.emplace_back(ctx.InputAt(j).dtype()); - } - vec_input_dtypes.emplace_back(std::move(dtypes)); - } - } - - VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; - auto output_dtypes = func(input_dtypes, vec_input_dtypes, ctx.Attrs()); - if (inplace_map.empty()) { - PADDLE_ENFORCE_EQ(outputs.size(), - output_dtypes.size(), - phi::errors::InvalidArgument( - "Your custom operator has set the InferDtypeFn. " - "However, `Outputs` size = %d does not match the " - "returned vector size of InferDtypeFn = %d. Please " - "check InferDtypeFn again.", - outputs.size(), - output_dtypes.size())); - } else { - PADDLE_ENFORCE_EQ( - outputs.size(), - output_dtypes.size() + inplace_map.size(), - phi::errors::InvalidArgument( - "Your custom operator uses `SetInplaceMap` and sets the " - "InferDtypeFn. However, `Outputs` size = %d does not match the " - "`InplaceMap size + InferDtypeFn output size` = %d. Please check " - "InplaceMap and InferDtypeFn again", - outputs.size(), - output_dtypes.size() + inplace_map.size())); - } - - VLOG(3) - << "Custom Operator: InferDtype - set output dtype: inplace_map.size() = " - << inplace_map.size() - << ", output_dtypes.size() = " << output_dtypes.size(); - size_t output_dtype_idx = 0; - auto inplace_reverse_map = ctx.GetInplaceReverseIndexMap(); - for (size_t i = 0; i < outputs.size(); ++i) { - if (paddle::framework::detail::IsDuplicableVar(outputs[i])) { - PADDLE_ENFORCE( - inplace_reverse_map.find(i) != inplace_reverse_map.end(), - phi::errors::InvalidArgument( - "Custom operator only supports `paddle::Vec(...)` inputs and " - "cannot support `paddle::Vec(...)` output without setting " - "InplaceMap. 
If you have to use `paddle::Vec(...)` output, " - "please indicate it by setting InplaceMap manully.")); - std::vector dtypes; - auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - for (size_t j = duplicable_input_pair.first; - j < duplicable_input_pair.second; - j++) { - dtypes.push_back(ctx.InputAt(j).dtype()); - } - result.emplace_back(std::move(dtypes)); - } else { - if (inplace_reverse_map.find(i) != inplace_reverse_map.end()) { - auto duplicable_input_pair = ctx.InputRangeAt(inplace_reverse_map[i]); - result.push_back({ctx.InputAt(duplicable_input_pair.first).dtype()}); - } else { - result.push_back({output_dtypes[output_dtype_idx++]}); - } - } - } - return result; -} - -paddle::Tensor BuildEmptyDistPaddleTensor( - const phi::distributed::ProcessMesh& process_mesh, - const phi::DDim& dims, - phi::DataType dtype) { - paddle::Tensor empty_tensor; - phi::DenseTensorMeta meta; - meta.dims = dims; - meta.dtype = dtype; - - auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(dims)); - dist_attr.set_process_mesh(process_mesh); - - auto dist_t = std::make_shared( - std::make_shared( - std::make_shared( - nullptr, 0, phi::distributed::GetDefaultPlace()), - meta), - dist_attr); - empty_tensor.set_impl(dist_t); - return empty_tensor; -} - -void run_custom_op_kernel( - paddle::CustomOpKernelContext& ctx, // NOLINT - const std::vector& vec_map, - const std::vector& inputs, - const std::vector& outputs, - const std::unordered_map& inplace_map) { - bool run_auto_parallel = false; - bool rank_is_in_current_mesh = true; - phi::distributed::ProcessMesh current_process_mesh; - std::vector* all_inputs = ctx.AllMutableInput(); - -#ifdef PADDLE_WITH_DISTRIBUTE - std::vector x = *all_inputs; - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, x)) { - ConvertAllInputsToDistTensor(mesh, x); - } - - run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); - rank_is_in_current_mesh = true; - if (run_auto_parallel) { - for (size_t i = 0; i < all_inputs->size(); ++i) { - PADDLE_ENFORCE_EQ( - all_inputs->at(i).initialized() && all_inputs->at(i).is_gpu(), - true, - phi::errors::InvalidArgument( - "The custom op's input tensor must be initialized " - "tensor on gpu, in AutoParallel mode.")); - } - - auto mesh = - std::static_pointer_cast(x.at(0).impl()) - ->dist_attr() - .process_mesh(); - rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); - - std::vector input_x(x.size()); - for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x.at(i).impl().get(); - } - - auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); - auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); - current_process_mesh = spmd_info.first[0].process_mesh(); - - if (rank_is_in_current_mesh) { - auto* dev_ctx = static_cast( - phi::DeviceContextPool::Instance().Get(x.at(0).place())); - auto dist_input_x = - paddle::experimental::ReshardApiInputToReplicatedKernelInput( - dev_ctx, x, spmd_info.first); - for (size_t i = 0; i < x.size(); ++i) { - all_inputs->at(i).set_impl(std::make_shared( - *(dist_input_x[i]->unsafe_mutable_value()))); - } - } else { - auto& infer_shape_func = - paddle::OpMetaInfoHelper::GetInferShapeFn(vec_map[0]); - auto& infer_dtype_func = - paddle::OpMetaInfoHelper::GetInferDtypeFn(vec_map[0]); - - std::vector> out_dims; - if (infer_shape_func) { - out_dims = RunInferShapeFunc( - ctx, infer_shape_func, inputs, outputs, inplace_map); - } else { - out_dims = 
RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); - } - - std::vector> out_dtypes; - if (infer_dtype_func) { - out_dtypes = RunInferDtypeFunc( - ctx, infer_dtype_func, inputs, outputs, inplace_map); - } else { - out_dtypes = - RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); - } - - PADDLE_ENFORCE_EQ( - out_dims.size(), - out_dtypes.size(), - phi::errors::InvalidArgument("custome op infer_shape and infer_dtype " - "must have the same output size.")); - - for (size_t i = 0; i < out_dims.size(); ++i) { - const auto& out_dim = out_dims.at(i); - const auto& out_dtype = out_dtypes.at(i); - PADDLE_ENFORCE_EQ( - out_dim.size(), - out_dtype.size(), - phi::errors::InvalidArgument( - "custome op infer_shape result[%d] and infer_dtype result[%d] " - "must have the same output size.", - i, - i)); - if (out_dim.size() == 0) { - ctx.EmplaceBackOutput(std::move(paddle::Tensor())); - } else if (out_dim.size() == 1) { - ctx.EmplaceBackOutput(std::move(BuildEmptyDistPaddleTensor( - current_process_mesh, out_dim[0], out_dtype[0]))); - } else { - std::vector out_tensors; - out_tensors.reserve(out_dim.size()); - for (size_t j = 0; j < out_dim.size(); ++j) { - out_tensors.emplace_back(BuildEmptyDistPaddleTensor( - current_process_mesh, out_dim[j], out_dtype[j])); - } - ctx.EmplaceBackOutputs(out_tensors); - } - } - return; - } - } -#endif - - for (size_t i = 0; i < all_inputs->size(); ++i) { - auto& tensor = all_inputs->at(i); - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast(tensor.impl())))))); - } - } - - const auto& inplace_reverse_idx_map = ctx.GetInplaceReverseIndexMap(); - for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { - const auto& output = outputs.at(out_idx); - // inplace special case - if (inplace_reverse_idx_map.find(out_idx) != - inplace_reverse_idx_map.end()) { - size_t in_idx = inplace_reverse_idx_map.at(out_idx); - const auto& input_range = ctx.InputRangeAt(in_idx); - const auto& input_tensor = ctx.InputAt(input_range.first); - // inplace optional [Tensor or vector], un-initialized tensor. - if (paddle::framework::detail::IsOptionalVar(output) && - !input_tensor.initialized()) { - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add un-initialized tensor " - "because the inplace optional input is None"; - ctx.EmplaceBackOutput(std::move(paddle::Tensor())); - continue; - } - /// inplace vector, initialized tensor. - if (paddle::framework::detail::IsDuplicableVar(output)) { - std::vector empty_tensors; - size_t vector_size = input_range.second - input_range.first; - empty_tensors.resize(vector_size); - for (size_t i = 0; i < vector_size; ++i) { - empty_tensors[i] = InitializedEmptyTensor(); - } - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add vector size = " - << empty_tensors.size(); - ctx.EmplaceBackOutputs(empty_tensors); - continue; - } - } - VLOG(7) << "Custom operator add output " << output - << " to CustomOpKernelContext. Add initialized Tensor because " - "using general or inplace mechanism"; - // general Tensor or inplace Tensor, initialized tensor. 
- ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); - } - - // handle inplace map - ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); - VLOG(7) << "Begin run Kernel of Custom Op"; - (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); - ctx.AssignInplaceOutputs(); - -#ifdef PADDLE_WITH_DISTRIBUTE - if (run_auto_parallel) { - std::vector* output_all = ctx.AllMutableOutput(); - for (size_t i = 0; i < output_all->size(); ++i) { - auto& tensor = output_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); - auto dist_t = std::make_shared( - std::dynamic_pointer_cast(tensor.impl()), - dist_attr); - tensor.set_impl(dist_t); - } - std::vector* input_all = ctx.AllMutableInput(); - for (size_t i = 0; i < input_all->size(); ++i) { - auto& tensor = input_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(phi::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); - auto dist_t = std::make_shared( - std::dynamic_pointer_cast(tensor.impl()), - dist_attr); - tensor.set_impl(dist_t); - } - } -#endif -} - static PyObject* eager_api_run_custom_op(PyObject* self, PyObject* args, PyObject* kwargs) { @@ -1062,7 +541,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, const auto& attrs = paddle::OpMetaInfoHelper::GetAttrs(vec_map[0]); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(vec_map[0]); const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[0]); - ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + for (size_t i = 0; i < inputs.size(); ++i) { const auto& input = inputs.at(i); // Parse op_type first, so that use i + 1 @@ -1141,8 +620,49 @@ static PyObject* eager_api_run_custom_op(PyObject* self, { eager_gil_scoped_release guard; + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + const auto& inplace_reverse_idx_map = ctx.GetInplaceReverseIndexMap(); + for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { + const auto& output = outputs.at(out_idx); + // inplace special case + if (inplace_reverse_idx_map.find(out_idx) != + inplace_reverse_idx_map.end()) { + size_t in_idx = inplace_reverse_idx_map.at(out_idx); + const auto& input_range = ctx.InputRangeAt(in_idx); + const auto& input_tensor = ctx.InputAt(input_range.first); + // inplace optional [Tensor or vector], un-initialized tensor. + if (paddle::framework::detail::IsOptionalVar(output) && + !input_tensor.initialized()) { + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add un-initialized tensor " + "because the inplace optional input is None"; + ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + continue; + } + /// inplace vector, initialized tensor. + if (paddle::framework::detail::IsDuplicableVar(output)) { + std::vector empty_tensors; + size_t vector_size = input_range.second - input_range.first; + empty_tensors.resize(vector_size); + for (size_t i = 0; i < vector_size; ++i) { + empty_tensors[i] = InitializedEmptyTensor(); + } + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add vector size = " + << empty_tensors.size(); + ctx.EmplaceBackOutputs(empty_tensors); + continue; + } + } + VLOG(7) << "Custom operator add output " << output + << " to CustomOpKernelContext. Add initialized Tensor because " + "using general or inplace mechanism"; + // general Tensor or inplace Tensor, initialized tensor. 
+ ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); + } + VLOG(7) << "Run Kernel of Custom Op: " << op_type; - run_custom_op_kernel(ctx, vec_map, inputs, outputs, inplace_map); + egr::run_custom_op_impl(vec_map[0], ctx); // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py index b7150292403e39..4b3c51328eb509 100644 --- a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -91,10 +91,10 @@ def test_body(self, x_shape, x_specs): dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) dist_x.stop_gradient = False - x = paddle.add(x, x) - dist_x = paddle.add(dist_x, dist_x) - out = custom_ops.custom_relu(x) - dist_out = custom_ops.custom_relu(dist_x) + y = paddle.add(x, x) + dist_y = paddle.add(dist_x, dist_x) + out = custom_ops.custom_relu(y) + dist_out = custom_ops.custom_relu(dist_y) out.stop_gradient = False dist_out.stop_gradient = False From 74908389668aac8734cea014776278cfdd102bac Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 11:05:37 +0000 Subject: [PATCH 07/23] refine --- .../custom_operator_run_kernel_impl.cc | 64 +++++++++++++------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc index d022bac23abddd..482a58a6426a51 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -366,22 +366,51 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, #ifdef PADDLE_WITH_DISTRIBUTE std::vector x = *all_inputs; const phi::distributed::ProcessMesh* mesh = nullptr; - if (paddle::pybind::InputsContainDistTensor(&mesh, x)) { - paddle::pybind::ConvertAllInputsToDistTensor(mesh, x); + for (auto& input : x) { + if (input.is_dist_tensor()) { + mesh = &( + std::dynamic_pointer_cast(input.impl()) + ->dist_attr() + .process_mesh()); + break; + } + } + + if (mesh) { + for (auto& input : x) { + if (input.is_dist_tensor()) { + PADDLE_ENFORCE_EQ( + std::dynamic_pointer_cast( + input.impl()) + ->dist_attr() + .process_mesh(), + *mesh, + platform::errors::InvalidArgument( + "Input %s has different mesh. 
However all inputs should " + "have the same mesh.", + input.name())); + return; + } else { + PADDLE_ENFORCE_EQ( + phi::DenseTensor::classof(input.impl().get()), + true, + platform::errors::InvalidArgument("Failed to convert input %s impl " + "to phi::distributed::DistTensor " + "as it's not phi::DenseTensor.", + input.name())); + phi::distributed::TensorDistAttr dist_attr( + phi::vectorize(input.impl()->dims())); + dist_attr.set_process_mesh(*mesh); + auto dense_t = std::static_pointer_cast(input.impl()); + input.set_impl( + std::make_shared(dense_t, dist_attr)); + } + } } run_auto_parallel = paddle::experimental::AllInputsAreDistTensor(x); rank_is_in_current_mesh = true; if (run_auto_parallel) { - for (size_t i = 0; i < all_inputs->size(); ++i) { - PADDLE_ENFORCE_EQ( - all_inputs->at(i).initialized() && all_inputs->at(i).is_gpu(), - true, - phi::errors::InvalidArgument( - "The custom op's input tensor must be initialized " - "tensor on gpu, in AutoParallel mode.")); - } - auto mesh = std::static_pointer_cast(x.at(0).impl()) ->dist_attr() @@ -399,8 +428,7 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, current_process_mesh = spmd_info.first[0].process_mesh(); if (rank_is_in_current_mesh) { - auto* dev_ctx = static_cast( - phi::DeviceContextPool::Instance().Get(x.at(0).place())); + auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); auto dist_input_x = paddle::experimental::ReshardApiInputToReplicatedKernelInput( dev_ctx, x, spmd_info.first); @@ -443,11 +471,11 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, PADDLE_ENFORCE_EQ( out_dim.size(), out_dtype.size(), - phi::errors::InvalidArgument( - "custome op infer_shape result[%d] and infer_dtype result[%d] " - "must have the same output size.", - i, - i)); + phi::errors::InvalidArgument("custome op infer_shape result[%d] " + "and infer_dtype result[%d] " + "must have the same output size.", + i, + i)); if (out_dim.size() == 0) { ctx.EmplaceBackOutput(std::move(paddle::Tensor())); } else if (out_dim.size() == 1) { From da6a898438fe586eecc43122ef540f3323c27ff8 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 11:06:11 +0000 Subject: [PATCH 08/23] refine --- .../eager/custom_operator/custom_operator_run_kernel_impl.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc index 482a58a6426a51..5c47363bd7a7ec 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -17,7 +17,6 @@ #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" From b8f5a485967468ba34649a606075d17532b5c1cb Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 11:07:36 +0000 Subject: [PATCH 09/23] refine --- .../custom_operator/custom_operator_run_kernel_impl.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc index 5c47363bd7a7ec..527f3c28a65bd8 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ 
b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -384,7 +384,7 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, ->dist_attr() .process_mesh(), *mesh, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Input %s has different mesh. However all inputs should " "have the same mesh.", input.name())); @@ -393,10 +393,10 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, PADDLE_ENFORCE_EQ( phi::DenseTensor::classof(input.impl().get()), true, - platform::errors::InvalidArgument("Failed to convert input %s impl " - "to phi::distributed::DistTensor " - "as it's not phi::DenseTensor.", - input.name())); + phi::errors::InvalidArgument("Failed to convert input %s impl " + "to phi::distributed::DistTensor " + "as it's not phi::DenseTensor.", + input.name())); phi::distributed::TensorDistAttr dist_attr( phi::vectorize(input.impl()->dims())); dist_attr.set_process_mesh(*mesh); From 40d20704cf0a947f5063065728de77f8757c0e7a Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 1 Nov 2023 11:26:03 +0000 Subject: [PATCH 10/23] refine --- .../custom_operator/custom_operator_run_kernel_impl.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc index 527f3c28a65bd8..187c19be511f6e 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -328,6 +328,7 @@ static std::vector> RunInferDtypeFunc( return result; } +#ifdef PADDLE_WITH_DISTRIBUTE paddle::Tensor BuildEmptyDistPaddleTensor( const phi::distributed::ProcessMesh& process_mesh, const phi::DDim& dims, @@ -349,6 +350,7 @@ paddle::Tensor BuildEmptyDistPaddleTensor( empty_tensor.set_impl(dist_t); return empty_tensor; } +#endif void run_custom_op_impl(paddle::OpMetaInfo op_info, paddle::CustomOpKernelContext& ctx) { // NOLINT @@ -357,12 +359,13 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + std::vector* all_inputs = ctx.AllMutableInput(); + +#ifdef PADDLE_WITH_DISTRIBUTE bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; - std::vector* all_inputs = ctx.AllMutableInput(); -#ifdef PADDLE_WITH_DISTRIBUTE std::vector x = *all_inputs; const phi::distributed::ProcessMesh* mesh = nullptr; for (auto& input : x) { From 3f38300f39612a14c6c39cbc0c7008b446c706e8 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 02:14:06 +0000 Subject: [PATCH 11/23] refine --- .../custom_operator/custom_operator_run_kernel_impl.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc index 187c19be511f6e..4782ebf9af7b02 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc @@ -426,8 +426,12 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); - current_process_mesh = spmd_info.first[0].process_mesh(); + 
phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_input_x); + current_process_mesh = + paddle::holds_alternative( + spmd_info.first[0]) + ? paddle::get<0>(spmd_info.first[0]).process_mesh() + : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); if (rank_is_in_current_mesh) { auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); From b0c889148bd9ca35a6516fa29aaaa0b67dfa782a Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 02:16:20 +0000 Subject: [PATCH 12/23] refine --- paddle/fluid/eager/custom_operator/CMakeLists.txt | 6 +++--- paddle/fluid/eager/custom_operator/custom_operator_node.cc | 2 +- ...operator_run_kernel_impl.cc => custom_operator_utils.cc} | 2 +- ...m_operator_run_kernel_impl.h => custom_operator_utils.h} | 0 paddle/fluid/pybind/eager_functions.cc | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) rename paddle/fluid/eager/custom_operator/{custom_operator_run_kernel_impl.cc => custom_operator_utils.cc} (99%) rename paddle/fluid/eager/custom_operator/{custom_operator_run_kernel_impl.h => custom_operator_utils.h} (100%) diff --git a/paddle/fluid/eager/custom_operator/CMakeLists.txt b/paddle/fluid/eager/custom_operator/CMakeLists.txt index 88de909ec4965c..a74ba2dc8c6287 100644 --- a/paddle/fluid/eager/custom_operator/CMakeLists.txt +++ b/paddle/fluid/eager/custom_operator/CMakeLists.txt @@ -1,9 +1,9 @@ cc_library( custom_operator_node SRCS custom_operator_node.cc - DEPS phi grad_node_info custom_operator utils custom_operator_run_kernel_impl) + DEPS phi grad_node_info custom_operator utils custom_operator_utils) cc_library( - custom_operator_run_kernel_impl - SRCS custom_operator_run_kernel_impl.cc + custom_operator_utils + SRCS custom_operator_utils.cc DEPS phi grad_node_info custom_operator utils) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 089a0321f175d4..eedcdecc1d850a 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/eager/custom_operator/custom_operator_node.h" -#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc similarity index 99% rename from paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc rename to paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 4782ebf9af7b02..2c5c0f15f20789 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" diff --git a/paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h b/paddle/fluid/eager/custom_operator/custom_operator_utils.h similarity index 100% rename from paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h rename to paddle/fluid/eager/custom_operator/custom_operator_utils.h diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index e170c335b30051..3afa624404c28a 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -62,7 +62,7 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/pybind/cuda_streams_py.h" #endif -#include "paddle/fluid/eager/custom_operator/custom_operator_run_kernel_impl.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" #include "paddle/phi/api/lib/data_transform.h" From b9fe1261dcf34e03458adc7e933ab524975733eb Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 08:28:05 +0000 Subject: [PATCH 13/23] refine --- .../custom_operator/custom_operator_utils.cc | 56 ++++++++++++------- paddle/fluid/pybind/eager_functions.cc | 3 +- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 2c5c0f15f20789..723b140bf01a33 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -108,7 +108,7 @@ static std::vector> RunInferShapeFunc( VLOG(3) << "Custom Operator: InferShape - get input ddim."; for (size_t i = 0; i < ctx.InputRange().size(); ++i) { const auto& input_pair = ctx.InputRangeAt(i); - if (input_pair.first == input_pair.second) { + if (input_pair.first == input_pair.second - 1) { input_shapes.emplace_back( std::move(ctx.InputAt(input_pair.first).shape())); } else { @@ -256,7 +256,7 @@ static std::vector> RunInferDtypeFunc( VLOG(3) << "Custom Operator: InferDtype - get input dtype."; for (size_t i = 0; i < ctx.InputRange().size(); ++i) { const auto& input_pair = ctx.InputRangeAt(i); - if (input_pair.first == input_pair.second) { + if (input_pair.first == input_pair.second - 1) { input_dtypes.emplace_back( std::move(ctx.InputAt(input_pair.first).dtype())); } else { @@ -391,7 +391,6 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, "Input %s has different mesh. 
However all inputs should " "have the same mesh.", input.name())); - return; } else { PADDLE_ENFORCE_EQ( phi::DenseTensor::classof(input.impl().get()), @@ -437,7 +436,7 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); auto dist_input_x = paddle::experimental::ReshardApiInputToReplicatedKernelInput( - dev_ctx, x, spmd_info.first); + dev_ctx, x, spmd_info.first[0]); for (size_t i = 0; i < x.size(); ++i) { all_inputs->at(i).set_impl(std::make_shared( *(dist_input_x[i]->unsafe_mutable_value()))); @@ -467,34 +466,49 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, PADDLE_ENFORCE_EQ( out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + PADDLE_ENFORCE_EQ( out_dtypes.size(), - phi::errors::InvalidArgument("custome op infer_shape and infer_dtype " - "must have the same output size.")); + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); for (size_t i = 0; i < out_dims.size(); ++i) { const auto& out_dim = out_dims.at(i); const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); PADDLE_ENFORCE_EQ( out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); + PADDLE_ENFORCE_EQ( out_dtype.size(), - phi::errors::InvalidArgument("custome op infer_shape result[%d] " - "and infer_dtype result[%d] " - "must have the same output size.", + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", i, - i)); - if (out_dim.size() == 0) { - ctx.EmplaceBackOutput(std::move(paddle::Tensor())); - } else if (out_dim.size() == 1) { - ctx.EmplaceBackOutput(std::move(BuildEmptyDistPaddleTensor( - current_process_mesh, out_dim[0], out_dtype[0]))); + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[0], out_dtype[0]); } else { - std::vector out_tensors; - out_tensors.reserve(out_dim.size()); - for (size_t j = 0; j < out_dim.size(); ++j) { - out_tensors.emplace_back(BuildEmptyDistPaddleTensor( - current_process_mesh, out_dim[j], out_dtype[j])); + for (size_t j = input_pair.first; j < input_pair.second; j++) { + *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( + current_process_mesh, out_dim[0], out_dtype[0]); } - ctx.EmplaceBackOutputs(out_tensors); } } return; diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 3afa624404c28a..b1f3d6cb8b12c4 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -671,7 +671,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(outputs.at(i)), + paddle::framework::detail::IsOptionalVar(outputs.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom operator's %d-th output is not initialized. " "Please check your implementation again. 
If you are " From e069ba3a6ca1d0e25adf54d7a4e679c7f006b825 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 08:30:19 +0000 Subject: [PATCH 14/23] refine --- paddle/fluid/eager/custom_operator/custom_operator_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 723b140bf01a33..041faf1ce56ddd 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -505,7 +505,7 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]); } else { - for (size_t j = input_pair.first; j < input_pair.second; j++) { + for (size_t j = pair.first; j < pair.second; j++) { *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]); } From 64fb686d9dc019962f0b474baf9e7a02a6061844 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 09:17:24 +0000 Subject: [PATCH 15/23] refine --- .../custom_operator/custom_operator_utils.cc | 2 + ...mi_auto_parallel_simple_net_custom_relu.py | 138 ++++++++++++++++++ ...test_semi_auto_parallel_single_strategy.py | 10 ++ 3 files changed, 150 insertions(+) create mode 100644 test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 041faf1ce56ddd..ffb66362229ef2 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/eager/custom_operator/custom_operator_utils.h" +#include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/enforce.h" @@ -348,6 +349,7 @@ paddle::Tensor BuildEmptyDistPaddleTensor( meta), dist_attr); empty_tensor.set_impl(dist_t); + empty_tensor.set_autograd_meta(std::make_shared()); return empty_tensor; } #endif diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py new file mode 100644 index 00000000000000..3702053cb87a5d --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from site import getsitepackages + +from semi_auto_parallel_simple_net import ( + PPDemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after installing +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [] +for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] +extra_nvcc_args = ['-O3'] + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\dist_custom_relu\\dist_custom_relu.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\base\\tests\\auto_parallel" +else: + test_include = "../python/paddle/base/tests/auto_parallel" +paddle_includes.append(test_include) + +custom_ops = load( + name='dist_custom_relu_jit', + sources=[ + 'dist_custom_relu_op.cc', + 'dist_custom_relu_op_dup.cc', + 'dist_custom_relu_op.cu', + ], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True, +) + + +class PPDemoNetCustomRelu(PPDemoNet): + def __init__(self, mesh0, mesh1, param_suffix=""): + super().__init__(mesh0, mesh1, param_suffix) + + def forward(self, x): + out = F.linear(x, self.w0) + out = custom_ops.custom_relu(out) + # out = F.relu(out) + out = dist.reshard(out, dist_attr=self.replicate_dist_attr1) + out = F.linear(out, self.w1) + return out + + +class TestSimpleNetWithCustomReluForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._pp_mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + self._pp_mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic_custom_relu(self, layer, shard_input=False): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + if shard_input: + image = dist.shard_tensor( + image, + dist_attr=dist.DistAttr( + mesh=self._mesh, sharding_specs=['x', None] + ), + ) + out = layer(image) + + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + + # loss.backward() + + def test_demo_net(self): + mp_layer = dist.shard_layer( + PPDemoNetCustomRelu(self._pp_mesh0, self._pp_mesh1), + self._mesh, + self.shard_fn, + ) + self.run_dynamic_custom_relu(mp_layer) + + def run_test_case(self): + self.test_demo_net() + + +if __name__ 
== "__main__": + TestSimpleNetWithCustomReluForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 27bff7eda64faa..699e0206b7f3b5 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -112,6 +112,16 @@ def test_simple_net_zero_grads(self): user_defined_envs=envs, ) + def test_simple_net_custom_relu(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() From 635fe9be236d1ee17778a1cfe876298ad692893c Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 2 Nov 2023 10:55:22 +0000 Subject: [PATCH 16/23] refine --- .../semi_auto_parallel_simple_net_custom_relu.py | 5 ----- .../auto_parallel/test_semi_auto_parallel_single_strategy.py | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py index 3702053cb87a5d..40057195b01388 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -27,11 +27,6 @@ from paddle.utils.cpp_extension import get_build_directory, load from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS, run_cmd -BATCH_SIZE = 16 -BATCH_NUM = 4 -IMAGE_SIZE = 784 -CLASS_NUM = 10 - # Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. # `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find # paddle include directory. 
Because the following path is generated after installing diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 699e0206b7f3b5..d4d1418e831eba 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -113,6 +113,7 @@ def test_simple_net_zero_grads(self): ) def test_simple_net_custom_relu(self): + self._changeable_envs = {"backend": ["gpu"]} envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs ) From 322b1b39a3c18abab8bd601957135d94b7c2064c Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 3 Nov 2023 01:42:15 +0000 Subject: [PATCH 17/23] refine --- test/auto_parallel/CMakeLists.txt | 2 +- .../semi_auto_parallel_simple_net_custom_relu.py | 2 +- test/auto_parallel/test_semi_auto_parallel_basic.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index aa1212c9ecc11d..e0a868c4368e57 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -122,7 +122,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_semi_auto_parallel_single_strategy MODULES test_semi_auto_parallel_single_strategy) set_tests_properties(test_semi_auto_parallel_single_strategy - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300) py_test_modules(test_semi_auto_parallel_hybrid_strategy MODULES test_semi_auto_parallel_hybrid_strategy) set_tests_properties(test_semi_auto_parallel_hybrid_strategy diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py index 40057195b01388..683776a626ab6c 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -115,7 +115,7 @@ def run_dynamic_custom_relu(self, layer, shard_input=False): label = paddle.to_tensor(self.label) loss = loss_fn(out, label) - # loss.backward() + loss.backward() def test_demo_net(self): mp_layer = dist.shard_layer( diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index b1132a6a3a8dce..2589566cb670ec 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -96,6 +96,16 @@ def test_add_n_api(self): user_defined_envs=envs, ) + def test_custom_relu_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_custom_relu.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() From be791d200e6a8a292600cc893127b5c1c9c9bfff Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 3 Nov 2023 07:10:09 +0000 Subject: [PATCH 18/23] refine --- .../custom_operator/custom_operator_node.cc | 8 +- .../custom_operator/custom_operator_utils.cc | 116 +++++++++++++++++- .../custom_operator/custom_operator_utils.h | 2 + paddle/fluid/pybind/eager_functions.cc | 2 +- ...mi_auto_parallel_simple_net_custom_relu.py | 42 +++++-- 5 files changed, 156 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index eedcdecc1d850a..9b6318c7a43ed3 100644 
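
For readers tracking the test changes in the last few patches, the end-to-end usage pattern they exercise is worth spelling out once in one place. The snippet below is an illustrative reduction of that pattern and is not a file touched by this series: the relu source paths, the two-rank mesh and the shard/backward flow mirror what the added tests do, while the function name and tensor shapes are assumptions made only for this sketch.

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.utils.cpp_extension import load

# JIT-compile the custom relu from the shared custom_op sources, matching the
# paths the auto_parallel tests switch to later in this series.
custom_ops = load(
    name='dist_custom_relu_jit',
    sources=[
        '../custom_op/custom_relu_op.cc',
        '../custom_op/custom_relu_op_dup.cc',
        '../custom_op/custom_relu_op.cu',
    ],
    verbose=True,
)


def run_custom_relu_semi_auto():
    # Assumes the script is started on two ranks (e.g. via the distributed
    # launcher) so that both ranks are part of the mesh.
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    image = paddle.rand([16, 784], dtype="float32")
    label = paddle.rand([16, 784], dtype="float32")

    # Shard the input along the "x" axis of the mesh; the custom op then
    # receives a DistTensor, which is exactly the path this series enables.
    dist_image = dist.shard_tensor(
        image,
        dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None]),
    )
    dist_image.stop_gradient = False

    out = custom_ops.custom_relu(dist_image)

    # Backward runs through the custom op's registered grad kernel.
    loss = nn.MSELoss()(out, label)
    loss.backward()


if __name__ == "__main__":
    run_custom_relu_semi_auto()

The intent of the series is that nothing changes on the user side of a custom op; the dispatch into the DistTensor branch happens inside run_custom_op_impl.
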
--- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -250,7 +250,7 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; - run_custom_op_impl(vec_map[1], ctx); + run_custom_op_impl(vec_map[1], false, false, ctx); for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { auto output_pair = ctx.OutputRangeAt(i); @@ -264,7 +264,9 @@ RunCustomOpNode::operator()(paddle::small_vector, ctx.MutableOutputAt(ctx.OutputRangeAt(i).first); if (!out_tensor->initialized()) { PADDLE_ENFORCE( - paddle::framework::detail::IsOptionalVar(grad_outputs_names.at(i)), + paddle::framework::detail::IsOptionalVar( + grad_outputs_names.at(i)) || + out_tensor->is_dist_tensor(), phi::errors::InvalidArgument( "Custom grad operator's %d-th output is not initialized. " "Please check your implementation again. If you are " @@ -449,7 +451,7 @@ RunCustomOpDoubleGradNode::operator()( } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad_grad"; - run_custom_op_impl(vec_map[2], ctx); + run_custom_op_impl(vec_map[2], false, true, ctx); for (size_t i = 0; i < ctx.OutputRange().size(); ++i) { auto output_pair = ctx.OutputRangeAt(i); diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index ffb66362229ef2..1581cca6d265e1 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -96,6 +96,59 @@ static std::vector> RunDefaultInferShapeFunc( return result; } +static std::vector> RunDefaultGradInferShapeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + // 1. if forward input exists, gradient's shape is same with forward + // input + // default + // [Suitable for most situations] + // 2. 
if forward input not exists, and only contains one grad input and + // output, + // use grad input shape as grad output shape + // [Suitable for the situation that forward input is not used as + // backward input] + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dims()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dims()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dims()}); + } + } + } + return result; +} + static std::vector> RunInferShapeFunc( const paddle::CustomOpKernelContext& ctx, const paddle::InferShapeFunc& func, @@ -244,6 +297,50 @@ static std::vector> RunDefaultInferDtypeFunc( return result; } +static std::vector> RunDefaultGradInferDtypeFunc( + const paddle::CustomOpKernelContext& ctx, + const std::vector& grad_op_inputs, + const std::vector& grad_op_outputs, + bool is_double_grad) { + std::vector> result; + for (auto& out_name : grad_op_outputs) { + auto fwd_name = paddle::framework::detail::NoGrad(out_name, is_double_grad); + if (paddle::framework::detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the forward " + "input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + std::vector tmp; + for (size_t i = pair.first; i < pair.second; ++i) { + tmp.emplace_back(ctx.InputAt(i).dtype()); + } + result.emplace_back(std::move(tmp)); + } else { + if (grad_op_inputs.size() == grad_op_outputs.size()) { + result.push_back({ctx.InputAt(0).dtype()}); + } else { + auto iter = + std::find(grad_op_inputs.begin(), grad_op_inputs.end(), fwd_name); + PADDLE_ENFORCE_NE( + iter, + grad_op_inputs.end(), + phi::errors::NotFound("Custom grad operator should have the " + "forward input(%s) as backward input", + fwd_name)); + auto pair = ctx.InputRangeAt(iter - grad_op_inputs.begin()); + result.push_back({ctx.InputAt(pair.first).dtype()}); + } + } + } + return result; +} + static std::vector> RunInferDtypeFunc( const paddle::CustomOpKernelContext& ctx, const paddle::InferDtypeFunc& func, @@ -355,6 +452,8 @@ paddle::Tensor BuildEmptyDistPaddleTensor( #endif void run_custom_op_impl(paddle::OpMetaInfo op_info, + bool is_forward, + bool is_double_grad, paddle::CustomOpKernelContext& ctx) { // NOLINT const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); const auto& outputs = 
paddle::OpMetaInfoHelper::GetOutputs(op_info); @@ -454,7 +553,13 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, out_dims = RunInferShapeFunc( ctx, infer_shape_func, inputs, outputs, inplace_map); } else { - out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + if (is_forward) { + out_dims = + RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = RunDefaultGradInferShapeFunc( + ctx, inputs, outputs, is_double_grad); + } } std::vector> out_dtypes; @@ -462,8 +567,13 @@ void run_custom_op_impl(paddle::OpMetaInfo op_info, out_dtypes = RunInferDtypeFunc( ctx, infer_dtype_func, inputs, outputs, inplace_map); } else { - out_dtypes = - RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + if (is_forward) { + out_dtypes = + RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = RunDefaultGradInferDtypeFunc( + ctx, inputs, outputs, is_double_grad); + } } PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.h b/paddle/fluid/eager/custom_operator/custom_operator_utils.h index 4dfac34141aec1..ae15e2c3216f5b 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.h +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.h @@ -18,5 +18,7 @@ namespace egr { void run_custom_op_impl(paddle::OpMetaInfo op_info, + bool is_forward, + bool is_double_grad, paddle::CustomOpKernelContext& ctx); // NOLINT } // namespace egr diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index b1f3d6cb8b12c4..63e59bbcfeedee 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -662,7 +662,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, } VLOG(7) << "Run Kernel of Custom Op: " << op_type; - egr::run_custom_op_impl(vec_map[0], ctx); + egr::run_custom_op_impl(vec_map[0], true, false, ctx); // handle optional None output when construct backward graph for (size_t i = 0; i < ctx.OutputRange().size(); i++) { diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py index 683776a626ab6c..0018e1cd9d87d1 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -15,10 +15,7 @@ import os from site import getsitepackages -from semi_auto_parallel_simple_net import ( - PPDemoNet, - TestSimpleNetForSemiAutoParallel, -) +from semi_auto_parallel_simple_net import TestSimpleNetForSemiAutoParallel import paddle import paddle.distributed as dist @@ -70,10 +67,41 @@ verbose=True, ) +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + -class PPDemoNetCustomRelu(PPDemoNet): +class PPDemoNet(nn.Layer): def __init__(self, mesh0, mesh1, param_suffix=""): - super().__init__(mesh0, mesh1, param_suffix) + super().__init__() + self.replicate_dist_attr0 = dist.DistAttr( + mesh=mesh0, sharding_specs=[None, None] + ) + self.replicate_dist_attr1 = dist.DistAttr( + mesh=mesh1, sharding_specs=[None, None] + ) + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="pp_demo_weight_0" + param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr0, + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="pp_nemo_weight_1" 
+ param_suffix, + initializer=paddle.nn.initializer.Uniform(0, 1), + ), + ), + dist_attr=self.replicate_dist_attr1, + ) def forward(self, x): out = F.linear(x, self.w0) @@ -119,7 +147,7 @@ def run_dynamic_custom_relu(self, layer, shard_input=False): def test_demo_net(self): mp_layer = dist.shard_layer( - PPDemoNetCustomRelu(self._pp_mesh0, self._pp_mesh1), + PPDemoNet(self._pp_mesh0, self._pp_mesh1), self._mesh, self.shard_fn, ) From 903d61e3027ecb61cc66216b97030f93f682b741 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 6 Nov 2023 01:57:39 +0000 Subject: [PATCH 19/23] refine --- test/auto_parallel/dist_custom_relu_op.cc | 268 ------------------ test/auto_parallel/dist_custom_relu_op.cu | 166 ----------- test/auto_parallel/dist_custom_relu_op_dup.cc | 38 --- .../semi_auto_parallel_for_custom_relu.py | 6 +- ...mi_auto_parallel_simple_net_custom_relu.py | 6 +- 5 files changed, 6 insertions(+), 478 deletions(-) delete mode 100644 test/auto_parallel/dist_custom_relu_op.cc delete mode 100644 test/auto_parallel/dist_custom_relu_op.cu delete mode 100644 test/auto_parallel/dist_custom_relu_op_dup.cc diff --git a/test/auto_parallel/dist_custom_relu_op.cc b/test/auto_parallel/dist_custom_relu_op.cc deleted file mode 100644 index 5627bb28b921f4..00000000000000 --- a/test/auto_parallel/dist_custom_relu_op.cc +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/extension.h" - -#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") - -template -void relu_cpu_forward_kernel(const data_t* x_data, - data_t* out_data, - int64_t x_numel) { - PD_CHECK(x_data != nullptr, "x_data is nullptr."); - PD_CHECK(out_data != nullptr, "out_data is nullptr."); - for (int64_t i = 0; i < x_numel; ++i) { - out_data[i] = std::max(static_cast(0.), x_data[i]); - } -} - -template -void relu_cpu_backward_kernel(const data_t* grad_out_data, - const data_t* out_data, - data_t* grad_x_data, - int64_t out_numel) { - for (int64_t i = 0; i < out_numel; ++i) { - grad_x_data[i] = - grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); - } -} - -template -void relu_cpu_double_backward_kernel(const data_t* out_data, - const data_t* ddx_data, - data_t* ddout_data, - int64_t ddout_numel) { - for (int64_t i = 0; i < ddout_numel; ++i) { - ddout_data[i] = - ddx_data[i] * (out_data[i] > static_cast(0) ? 1. 
: 0.); - } -} - -std::vector relu_cpu_forward(const paddle::Tensor& x) { - auto out = paddle::empty_like(x); - - PD_DISPATCH_FLOATING_TYPES( - x.type(), "relu_cpu_forward", ([&] { - relu_cpu_forward_kernel( - x.data(), out.data(), x.numel()); - })); - - return {out}; -} - -std::vector relu_cpu_backward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out) { - auto grad_x = paddle::empty_like(x); - - PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { - relu_cpu_backward_kernel( - grad_out.data(), - out.data(), - grad_x.data(), - out.size()); - })); - - return {grad_x}; -} - -std::vector relu_cpu_double_backward( - const paddle::Tensor& out, const paddle::Tensor& ddx) { - CHECK_CPU_INPUT(out); - CHECK_CPU_INPUT(ddx); - auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); - - PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] { - relu_cpu_double_backward_kernel( - out.data(), - ddx.data(), - ddout.mutable_data(out.place()), - ddout.size()); - })); - - return {ddout}; -} - -std::vector relu_cuda_forward(const paddle::Tensor& x); -std::vector relu_cuda_backward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out); -std::vector relu_cuda_double_backward( - const paddle::Tensor& out, const paddle::Tensor& ddx); - -std::vector ReluForward(const paddle::Tensor& x) { - if (x.is_cpu()) { - return relu_cpu_forward(x); - } else if (x.is_gpu()) { - return relu_cuda_forward(x); - } else { - PD_THROW("Not implemented."); - } -} - -std::vector ReluBackward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out) { - if (x.is_cpu()) { - return relu_cpu_backward(x, out, grad_out); - } else if (x.is_gpu()) { - return relu_cuda_backward(x, out, grad_out); - } else { - PD_THROW("Not implemented."); - } -} - -std::vector ReluDoubleBackward(const paddle::Tensor& out, - const paddle::Tensor& ddx) { - if (out.is_cpu()) { - return relu_cpu_double_backward(out, ddx); - } else if (out.is_gpu()) { - return relu_cuda_double_backward(out, ddx); - } else { - PD_THROW("Not implemented."); - } -} - -std::vector> ReluDoubleBackwardInferShape( - const std::vector& out_shape, - const std::vector& ddx_shape) { - return {out_shape}; -} - -PD_BUILD_OP(custom_relu) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(ReluForward)); - -PD_BUILD_GRAD_OP(custom_relu) - .Inputs({"X", "Out", paddle::Grad("Out")}) - .Outputs({paddle::Grad("X")}) - .SetKernelFn(PD_KERNEL(ReluBackward)); - -PD_BUILD_DOUBLE_GRAD_OP(custom_relu) - .Inputs({"Out", paddle::Grad(paddle::Grad("X"))}) - .Outputs({paddle::Grad(paddle::Grad("Out"))}) - .SetKernelFn(PD_KERNEL(ReluDoubleBackward)) - .SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape)); - -std::vector relu_cpu_backward_without_x( - const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); - - PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { - relu_cpu_backward_kernel( - grad_out.data(), - out.data(), - grad_x.mutable_data(out.place()), - out.size()); - })); - - return {grad_x}; -} - -std::vector relu_cuda_backward_without_x( - const paddle::Tensor& out, const paddle::Tensor& grad_out); - -std::vector ReluBackwardWithoutX( - const paddle::Tensor& out, const paddle::Tensor& grad_out) { - if (out.is_cpu()) { - return relu_cpu_backward_without_x(out, grad_out); - } else if (out.is_gpu()) { - return relu_cuda_backward_without_x(out, grad_out); - } 
else { - PD_THROW("Not implemented."); - } -} - -std::vector> ReluBackwardWithoutXInferShape( - const std::vector& out_shape, - const std::vector& grad_out_shape) { - return {out_shape}; -} - -PD_BUILD_OP(custom_relu_no_x_in_backward) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(ReluForward)); - -PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward) - .Inputs({"Out", paddle::Grad("Out")}) - .Outputs({paddle::Grad("X")}) - .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX)) - .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape)); - -void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { - out->reshape(x.shape()); - PD_DISPATCH_FLOATING_TYPES( - x.type(), "relu_cpu_forward", ([&] { - relu_cpu_forward_kernel( - x.data(), out->mutable_data(x.place()), x.numel()); - })); -} - -void relu_cpu_backward_out(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out, - paddle::Tensor* grad_x) { - grad_x->reshape(x.shape()); - PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { - relu_cpu_backward_kernel( - grad_out.data(), - out.data(), - grad_x->mutable_data(x.place()), - out.size()); - })); -} - -void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out); -void relu_cuda_backward_out(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out, - paddle::Tensor* grad_x); - -void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) { - if (x.is_cpu()) { - return relu_cpu_forward_out(x, out); - } else if (x.is_gpu()) { - return relu_cuda_forward_out(x, out); - } else { - PD_THROW("Not implemented."); - } -} - -void ReluBackwardOut(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out, - paddle::Tensor* grad_x) { - if (x.is_cpu()) { - return relu_cpu_backward_out(x, out, grad_out, grad_x); - } else if (x.is_gpu()) { - return relu_cuda_backward_out(x, out, grad_out, grad_x); - } else { - PD_THROW("Not implemented."); - } -} - -PD_BUILD_OP(custom_relu_out) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(ReluForwardOut)); - -PD_BUILD_GRAD_OP(custom_relu_out) - .Inputs({"X", "Out", paddle::Grad("Out")}) - .Outputs({paddle::Grad("X")}) - .SetKernelFn(PD_KERNEL(ReluBackwardOut)); diff --git a/test/auto_parallel/dist_custom_relu_op.cu b/test/auto_parallel/dist_custom_relu_op.cu deleted file mode 100644 index 49e5d16938eb80..00000000000000 --- a/test/auto_parallel/dist_custom_relu_op.cu +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" - -#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") - -template -__global__ void relu_cuda_forward_kernel(const data_t* x, - data_t* y, - int64_t num) { - int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { - y[i] = x[i] > static_cast(0.) ? 
x[i] : static_cast(0.); - } -} - -template -__global__ void relu_cuda_backward_kernel(const data_t* dy, - const data_t* y, - data_t* dx, - int64_t num) { - int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { - dx[i] = dy[i] * (y[i] > static_cast(0.) ? static_cast(1.) - : static_cast(0.)); - } -} - -template -__global__ void relu_cuda_double_backward_kernel(const data_t* out_data, - const data_t* ddx_data, - data_t* ddout_data, - int64_t num) { - int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { - ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast(0.) - ? static_cast(1.) - : static_cast(0.)); - } -} - -std::vector relu_cuda_forward(const paddle::Tensor& x) { - CHECK_GPU_INPUT(x); - auto out = paddle::empty_like(x); - - PD_CHECK(x.place() == paddle::DefaultGPUPlace()); - - int64_t numel = x.numel(); - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - x.type(), "relu_cuda_forward_kernel", ([&] { - relu_cuda_forward_kernel<<>>( - x.data(), out.data(), numel); - })); - - return {out}; -} - -std::vector relu_cuda_backward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out) { - CHECK_GPU_INPUT(x); - CHECK_GPU_INPUT(out); - CHECK_GPU_INPUT(grad_out); - auto grad_x = paddle::empty_like(x); - - PD_CHECK(x.place() == paddle::DefaultGPUPlace()); - - int64_t numel = out.numel(); - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - out.type(), "relu_cuda_backward_kernel", ([&] { - relu_cuda_backward_kernel<<>>( - grad_out.data(), - out.data(), - grad_x.mutable_data(x.place()), - numel); - })); - - return {grad_x}; -} - -std::vector relu_cuda_double_backward( - const paddle::Tensor& out, const paddle::Tensor& ddx) { - CHECK_GPU_INPUT(out); - CHECK_GPU_INPUT(ddx); - auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); - - int64_t numel = out.numel(); - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - out.type(), "relu_cuda_double_backward_kernel", ([&] { - relu_cuda_double_backward_kernel - <<>>( - out.data(), - ddx.data(), - ddout.mutable_data(out.place()), - numel); - })); - - return {ddout}; -} - -std::vector relu_cuda_backward_without_x( - const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); - - int numel = out.numel(); - int block = 512; - int grid = (numel + block - 1) / block; - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - out.type(), "relu_cuda_backward_kernel", ([&] { - relu_cuda_backward_kernel<<>>( - grad_out.data(), - out.data(), - grad_x.mutable_data(out.place()), - numel); - })); - - return {grad_x}; -} - -void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { - int numel = x.numel(); - int block = 512; - int grid = (numel + block - 1) / block; - out->reshape(x.shape()); - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - x.type(), "relu_cuda_forward_kernel", ([&] { - relu_cuda_forward_kernel<<>>( - x.data(), out->mutable_data(x.place()), numel); - })); -} - -void relu_cuda_backward_out(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out, - paddle::Tensor* grad_x) { - int numel = out.numel(); - int block = 512; - int grid = (numel + block - 1) / block; - grad_x->reshape(x.shape()); - PD_DISPATCH_FLOATING_AND_HALF_TYPES( - 
out.type(), "relu_cuda_backward_kernel", ([&] { - relu_cuda_backward_kernel<<>>( - grad_out.data(), - out.data(), - grad_x->mutable_data(x.place()), - numel); - })); -} diff --git a/test/auto_parallel/dist_custom_relu_op_dup.cc b/test/auto_parallel/dist_custom_relu_op_dup.cc deleted file mode 100644 index 89d14bfa049603..00000000000000 --- a/test/auto_parallel/dist_custom_relu_op_dup.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" - -std::vector relu_cuda_forward(const paddle::Tensor& x); -std::vector relu_cuda_backward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out); - -std::vector ReluForward(const paddle::Tensor& x); - -std::vector ReluBackward(const paddle::Tensor& x, - const paddle::Tensor& out, - const paddle::Tensor& grad_out); - -// Reuse codes in `custom_relu_op.cc/cu` to register another custom operator -// to test jointly compile multi operators at same time. -PD_BUILD_OP(custom_relu_dup) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(ReluForward)); - -PD_BUILD_GRAD_OP(custom_relu_dup) - .Inputs({"X", "Out", paddle::Grad("Out")}) - .Outputs({paddle::Grad("X")}) - .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py index 4b3c51328eb509..07496ec07e5062 100644 --- a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -55,9 +55,9 @@ custom_ops = load( name='dist_custom_relu_jit', sources=[ - 'dist_custom_relu_op.cc', - 'dist_custom_relu_op_dup.cc', - 'dist_custom_relu_op.cu', + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', ], extra_include_paths=paddle_includes, # add for Coverage CI extra_cxx_cflags=extra_cc_args, # test for cc flags diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py index 0018e1cd9d87d1..ef8ff6e004c454 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -57,9 +57,9 @@ custom_ops = load( name='dist_custom_relu_jit', sources=[ - 'dist_custom_relu_op.cc', - 'dist_custom_relu_op_dup.cc', - 'dist_custom_relu_op.cu', + '../custom_op/custom_relu_op.cc', + '../custom_op/custom_relu_op_dup.cc', + '../custom_op/custom_relu_op.cu', ], extra_include_paths=paddle_includes, # add for Coverage CI extra_cxx_cflags=extra_cc_args, # test for cc flags From ea4aa2f10ac85614cdc27c80eff391e3659ede9a Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 6 Nov 2023 02:15:15 +0000 Subject: [PATCH 20/23] refine --- paddle/fluid/eager/custom_operator/custom_operator_utils.cc | 2 +- paddle/fluid/eager/custom_operator/custom_operator_utils.h | 2 +- 2 files changed, 2 
insertions(+), 2 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 1581cca6d265e1..20d84c8bea4dbc 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -451,7 +451,7 @@ paddle::Tensor BuildEmptyDistPaddleTensor( } #endif -void run_custom_op_impl(paddle::OpMetaInfo op_info, +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, bool is_forward, bool is_double_grad, paddle::CustomOpKernelContext& ctx) { // NOLINT diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.h b/paddle/fluid/eager/custom_operator/custom_operator_utils.h index ae15e2c3216f5b..ac2dec37f3d34c 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.h +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.h @@ -17,7 +17,7 @@ #include "paddle/phi/api/ext/op_meta_info.h" namespace egr { -void run_custom_op_impl(paddle::OpMetaInfo op_info, +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, bool is_forward, bool is_double_grad, paddle::CustomOpKernelContext& ctx); // NOLINT From 963c4a7ae3292b3f00b3dc437b8dc2c2cb81448e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 6 Nov 2023 06:48:05 +0000 Subject: [PATCH 21/23] refine --- .../custom_operator/custom_operator_utils.cc | 93 ++++++++++++------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 20d84c8bea4dbc..d8a4efa96c0b53 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -451,18 +451,12 @@ paddle::Tensor BuildEmptyDistPaddleTensor( } #endif -void run_custom_op_impl(const paddle::OpMetaInfo& op_info, - bool is_forward, - bool is_double_grad, - paddle::CustomOpKernelContext& ctx) { // NOLINT - const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); - const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); - const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); - ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); - - std::vector* all_inputs = ctx.AllMutableInput(); - #ifdef PADDLE_WITH_DISTRIBUTE +std::tuple PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; @@ -539,8 +533,8 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, paddle::experimental::ReshardApiInputToReplicatedKernelInput( dev_ctx, x, spmd_info.first[0]); for (size_t i = 0; i < x.size(); ++i) { - all_inputs->at(i).set_impl(std::make_shared( - *(dist_input_x[i]->unsafe_mutable_value()))); + all_inputs->at(i).set_impl( + std::make_shared(dist_input_x[i]->value())); } } else { auto& infer_shape_func = @@ -619,34 +613,22 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, } else { for (size_t j = pair.first; j < pair.second; j++) { *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( - current_process_mesh, out_dim[0], out_dtype[0]); + current_process_mesh, out_dim[j], out_dtype[j]); } } } - return; + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); } } +} #endif - for (size_t i = 0; i < all_inputs->size(); ++i) 
{ - auto& tensor = all_inputs->at(i); - if (tensor.initialized() && tensor.is_dense_tensor() && - !std::dynamic_pointer_cast(tensor.impl()) - ->meta() - .is_contiguous()) { - tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast(tensor.impl())))))); - } - } - - // handle inplace map - ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); - VLOG(7) << "Begin run Kernel of Custom Op"; - (*paddle::OpMetaInfoHelper::GetKernelFn(op_info))(&ctx); - ctx.AssignInplaceOutputs(); - #ifdef PADDLE_WITH_DISTRIBUTE +void TransCtxTensorsToDistTensors( + paddle::CustomOpKernelContext& ctx, // NOLINT + bool run_auto_parallel, + const phi::distributed::ProcessMesh& current_process_mesh) { if (run_auto_parallel) { std::vector* output_all = ctx.AllMutableOutput(); for (size_t i = 0; i < output_all->size(); ++i) { @@ -671,6 +653,51 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, tensor.set_impl(dist_t); } } +} +#endif + +void run_custom_op_impl(const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx) { // NOLINT + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); + + std::vector* all_inputs = ctx.AllMutableInput(); + +#ifdef PADDLE_WITH_DISTRIBUTE + auto result = + PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); + bool run_auto_parallel = std::get<0>(result); + bool rank_is_in_current_mesh = std::get<1>(result); + phi::distributed::ProcessMesh current_process_mesh = std::get<2>(result); + if (!rank_is_in_current_mesh) { + return; + } +#endif + + for (size_t i = 0; i < all_inputs->size(); ++i) { + auto& tensor = all_inputs->at(i); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl())))))); + } + } + + // handle inplace map + ctx.UpdatePlainOutputs(inputs, outputs, inplace_map); + VLOG(7) << "Begin run Kernel of Custom Op"; + (*paddle::OpMetaInfoHelper::GetKernelFn(op_info))(&ctx); + ctx.AssignInplaceOutputs(); + +#ifdef PADDLE_WITH_DISTRIBUTE + TransCtxTensorsToDistTensors(ctx, run_auto_parallel, current_process_mesh); #endif } From 5de93145b01689cd5bf36c40c9587e013764606e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 6 Nov 2023 07:00:24 +0000 Subject: [PATCH 22/23] refine --- .../eager/custom_operator/custom_operator_utils.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index d8a4efa96c0b53..02bef035ca2c8d 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -461,6 +461,11 @@ std::tuple PrepareCtxForAutoParallel( bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; + const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); + const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); + + std::vector* all_inputs = ctx.AllMutableInput(); std::vector x = 
*all_inputs; const phi::distributed::ProcessMesh* mesh = nullptr; for (auto& input : x) { @@ -621,6 +626,8 @@ std::tuple PrepareCtxForAutoParallel( run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); } } + return std::tuple( + run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); } #endif @@ -665,8 +672,6 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(op_info); ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); - std::vector* all_inputs = ctx.AllMutableInput(); - #ifdef PADDLE_WITH_DISTRIBUTE auto result = PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); @@ -678,6 +683,7 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, } #endif + std::vector* all_inputs = ctx.AllMutableInput(); for (size_t i = 0; i < all_inputs->size(); ++i) { auto& tensor = all_inputs->at(i); if (tensor.initialized() && tensor.is_dense_tensor() && From 81850168e2b0ac035a5043c9dd01569acaa94b14 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 6 Nov 2023 08:20:41 +0000 Subject: [PATCH 23/23] refine --- paddle/fluid/eager/custom_operator/custom_operator_utils.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 02bef035ca2c8d..7985ef92285d03 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -534,9 +534,8 @@ std::tuple PrepareCtxForAutoParallel( if (rank_is_in_current_mesh) { auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); - auto dist_input_x = - paddle::experimental::ReshardApiInputToReplicatedKernelInput( - dev_ctx, x, spmd_info.first[0]); + auto dist_input_x = paddle::experimental::ReshardApiInputToKernelInput( + dev_ctx, x, spmd_info.first[0]); for (size_t i = 0; i < x.size(); ++i) { all_inputs->at(i).set_impl( std::make_shared(dist_input_x[i]->value()));
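
A closing note on running the new coverage by hand: semi_auto_parallel_simple_net_custom_relu.py reads dtype, seed and backend from the environment and expects two ranks, so a manual run presumably looks like

    dtype=float32 seed=2023 backend=gpu python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_simple_net_custom_relu.py

where the launcher flags are the standard paddle.distributed.launch options and the concrete variable values are only examples; in CI the same scripts are driven by test_semi_auto_parallel_single_strategy.py and test_semi_auto_parallel_basic.py, which generate these environment combinations themselves.
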