From 5c5ae2e50c80d93f7988a52ca01e7d8920eb9318 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 17 Jan 2022 11:46:19 +0000 Subject: [PATCH 01/10] add kernel for c++ api --- cmake/pten_kernel.cmake | 38 ++-- paddle/fluid/operators/cholesky_solve_op.h | 3 +- .../elementwise/elementwise_add_op.h | 2 +- .../elementwise/elementwise_div_op.h | 3 +- .../elementwise/elementwise_mul_op.cu | 4 +- .../elementwise/elementwise_mul_op.h | 4 +- .../operators/elementwise/elementwise_op.h | 9 +- .../elementwise/elementwise_sub_op.h | 4 +- paddle/fluid/operators/lu_op.h | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 4 +- paddle/pten/api/include/kernel_signature.h | 7 - paddle/pten/core/kernel_alias_name.h | 6 - paddle/pten/kernels/cpu/math_kernel.cc | 76 +++---- paddle/pten/kernels/gpu/math_kernel.cu | 77 ++++--- paddle/pten/kernels/math_kernel.cc | 204 ++++++++++++++++++ paddle/pten/kernels/math_kernel.h | 124 ++++++----- .../tests/kernels/test_elementwise_dev_api.cc | 16 +- python/paddle/utils/code_gen/api.yaml | 7 +- 18 files changed, 397 insertions(+), 195 deletions(-) create mode 100644 paddle/pten/kernels/math_kernel.cc diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index bc9fefb58f4527..9d5d4957b4ef42 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -103,36 +103,34 @@ function(kernel_library TARGET) list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len) - if (${common_srcs_len} GREATER 0) - # If the kernel has a device independent public implementation, - # we will use this implementation and will not adopt the implementation - # under specific devices + if(${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR + ${xpu_srcs_len} GREATER 0) if (WITH_GPU) - nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() else() - cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() endif() else() - # If the kernel has a header file declaration, but no corresponding - # implementation can be found, this is not allowed - if (${cpu_srcs_len} EQUAL 0 AND ${gpu_srcs_len} EQUAL 0 AND - ${xpu_srcs_len} EQUAL 0) + if (${common_srcs_len} EQUAL 0) message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") else() + # If the kernel has a device independent public implementation, + # we will use this implementation and will not adopt the implementation + # under specific devices if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + 
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       else()
-        if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
-          cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
-        endif()
+        cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       endif()
     endif()
   endif()
 
diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h
index 157679f4fc90b4..da9ae69338cb5d 100644
--- a/paddle/fluid/operators/cholesky_solve_op.h
+++ b/paddle/fluid/operators/cholesky_solve_op.h
@@ -202,7 +202,8 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
     commonterm_for_range(commonterm_functor);
     commonterm_conj = helper.Transpose(commonterm_conj);
 
-    pten::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
+    pten::AddRawKernel<T>(dev_ctx, commonterm, commonterm_conj, -1,
+                          &commonterm);
 
     auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false);
     auto mat_dim_c =
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index 622a6d7edb783c..2ee97fc5a63aa1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -61,7 +61,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
+    pten::AddRawKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index d9f7bbc56a9021..d96c621fbbe3a6 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -51,7 +51,8 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::DivideKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
+    pten::DivideRawKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
+                             pt_z.get());
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index 5ece5cadc603fa..4fa39ffef40c8f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -50,8 +50,8 @@ class ElementwiseMulKernel
       auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
       auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod);
       auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
-      pten::MultiplyKernel<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
-                              pt_z.get());
+      pten::MultiplyRawKernel<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
+                                 pt_z.get());
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "X's type[%s] is not supported by elementwise_op. X's type should be "
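
These call-site hunks all follow one bridge pattern: the fluid kernel wraps its tensors as pten DenseTensors and forwards them, together with the op's axis attribute, to the renamed *_raw kernel. A minimal, self-contained sketch of that bridge with simplified stand-in types (MakePtenDenseTensor below is a hypothetical shim that merely aliases the buffer, not Paddle's real helper):

    #include <cassert>
    #include <memory>
    #include <vector>

    struct LoDTensor { std::vector<float> data; };    // fluid-side stand-in
    struct DenseTensor { std::vector<float>* data; }; // pten-side view
    struct CPUContext {};

    // Hypothetical shim: wraps a fluid tensor's buffer without copying.
    std::unique_ptr<DenseTensor> MakePtenDenseTensor(LoDTensor& t) {
      return std::unique_ptr<DenseTensor>(new DenseTensor{&t.data});
    }

    // Raw kernel: still takes the fluid "axis" attribute explicitly.
    void MultiplyRawKernel(const CPUContext&, const DenseTensor& x,
                           const DenseTensor& y, int /*axis*/, DenseTensor* z) {
      for (size_t i = 0; i < x.data->size(); ++i)
        (*z->data)[i] = (*x.data)[i] * (*y.data)[i];
    }

    int main() {
      LoDTensor x{{2, 3}}, y{{4, 5}}, z{{0, 0}};
      auto pt_x = MakePtenDenseTensor(x), pt_y = MakePtenDenseTensor(y),
           pt_z = MakePtenDenseTensor(z);
      MultiplyRawKernel(CPUContext{}, *pt_x, *pt_y, /*axis=*/-1, pt_z.get());
      assert(z.data[0] == 8 && z.data[1] == 15);
      return 0;
    }
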
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index 687340b668a13f..e75eae53b1c7be 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -124,8 +124,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
       auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
       auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
       auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
-      pten::MultiplyKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
-                              pt_z.get());
+      pten::MultiplyRawKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
+                                 pt_z.get());
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "X's type[%s] is not supported by elementwise_op. X's type should be "
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index e1d9655e293a3e..9e2ffbbb21b90b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -142,24 +142,25 @@ class ElementwiseOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     if (Type() == "elementwise_add") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("add", {"X", "Y"}, {"axis"}, {"Out"});
+        return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
+                                          {"Out"});
       }
     }
     if (Type() == "elementwise_sub") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("subtract", {"X", "Y"}, {"axis"},
+        return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_div") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("divide", {"X", "Y"}, {"axis"},
+        return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_mul") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature("multiply", {"X", "Y"}, {"axis"},
+        return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
index 0d889ef26c954f..992dfd4262675d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
@@ -51,8 +51,8 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
     auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
     auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
     auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
-    pten::SubtractKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
-                            pt_z.get());
+    pten::SubtractRawKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
+                               pt_z.get());
   }
 };
 
diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h
index f241caa857a07a..57aac0fc005180 100644
--- a/paddle/fluid/operators/lu_op.h
+++ b/paddle/fluid/operators/lu_op.h
@@ -221,7 +221,7 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
 
-  pten::AddKernel<T>(dev_ctx, src1, src2, -1, out);
+  pten::AddRawKernel<T>(dev_ctx, src1, src2, -1, out);
 }
 
 template <typename DeviceContext, typename T>
@@ -230,7 +230,7 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
 
-  pten::SubtractKernel<T>(dev_ctx, src1, src2, -1, out);
+  pten::SubtractRawKernel<T>(dev_ctx, src1, src2, -1, out);
 }
 
 template <typename DeviceContext, typename T>
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index eb4d4a5c1680ec..a6c51e79a258f5 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -550,14 +550,14 @@ class ReduceOp : public framework::OperatorWithKernel {
     if (Type() == "reduce_sum") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
         return framework::KernelSignature(
-            "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
+            "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
             {"Out"});
       }
     }
     if (Type() == "reduce_mean") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
         return framework::KernelSignature(
-            "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+            "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
       }
     }
     // TODO(chentianyu03): support other cases after selected rows added
diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h
index 0b17415a6a98de..1ad2f5da5d66ad 100644
--- a/paddle/pten/api/include/kernel_signature.h
+++ b/paddle/pten/api/include/kernel_signature.h
@@ -30,7 +30,6 @@ using DeviceContext = paddle::platform::DeviceContext;
 using add_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const DenseTensor&,
-                            int,
                             DenseTensor*);
 
 using cast_kernel = void (*)(const DeviceContext&,
@@ -41,7 +40,6 @@ using cast_kernel = void (*)(const DeviceContext&,
 using divide_kernel = void (*)(const DeviceContext&,
                                const DenseTensor&,
                                const DenseTensor&,
-                               int,
                                DenseTensor*);
 
 using dot_kernel = void (*)(const DeviceContext&,
@@ -77,13 +75,11 @@ using mean_kernel = void (*)(const DeviceContext&,
                              const DenseTensor&,
                              const std::vector<int64_t>&,
                              bool,
-                             bool,
                              DenseTensor*);
 
 using multiply_kernel = void (*)(const DeviceContext&,
                                  const DenseTensor&,
                                  const DenseTensor&,
-                                 int,
                                  DenseTensor*);
 
 using reshape_kernel = void (*)(const DeviceContext&,
@@ -102,14 +98,11 @@ using sum_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const std::vector<int64_t>&,
                             bool,
-                            bool,
-                            DataType,
                             DenseTensor*);
 
 using subtract_kernel = void (*)(const DeviceContext&,
                                  const DenseTensor&,
                                  const DenseTensor&,
-                                 int,
                                  DenseTensor*);
 
 using conj_kernel = void (*)(const DeviceContext&,
diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h
index 5c867879663684..b2c568d6ad4a7e 100644
--- a/paddle/pten/core/kernel_alias_name.h
+++ b/paddle/pten/core/kernel_alias_name.h
@@ -20,10 +20,6 @@ namespace pten {
 // the key is kernel_name in fluid, the value is the kernel_name in pten
 // the key is sorted by key's alphabet
 const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
-    {"elementwise_add", "add"},
-    {"elementwise_div", "divide"},
-    {"elementwise_mul", "muliply"},
-    {"elementwise_sub", "subtract"},
     {"fill_any_like", "full_like"},
     {"fill_constant", "full"},
     {"flatten_contiguous_range", "flatten"},
@@ -32,8 +28,6 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
     {"matmul_v2_grad", "matmul_grad"},
     {"matmul_v2_grad_grad", "matmul_double_grad"},
     {"matmul_v2_triple_grad", "matmul_triple_grad"},
-    {"reduce_mean", "mean"},
-    {"reduce_sum", "sum"},
     {"reshape2", "reshape"},
     {"reshape2_grad", "reshape_grad"},
     {"reshape2_grad_grad", "reshape_double_grad"},
diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc
index 83388d0d9a80fd..cbdfa38dbe6300 100644
--- a/paddle/pten/kernels/cpu/math_kernel.cc
+++ b/paddle/pten/kernels/cpu/math_kernel.cc
@@ -32,11 +32,11 @@ namespace pten {
 
 #define DEFINE_CPU_ELEMENTWISE_OP(name)                                     \
   template <typename T, typename Context>                                   \
-  void name##Kernel(const Context& dev_ctx,                                 \
-                    const DenseTensor& x,                                   \
-                    const DenseTensor& y,                                   \
-                    int axis,                                               \
-                    DenseTensor* out) {                                     \
+  void name##RawKernel(const Context& dev_ctx,                              \
+                       const DenseTensor& x,                                \
+                       const DenseTensor& y,                                \
+                       int axis,                                            \
+                       DenseTensor* out) {                                  \
     out->mutable_data<T>();                                                 \
     if (x.dims() == y.dims()) {                                             \
       SameDimsElementwiseCompute<SameDims##name##Functor<CPUContext, T>>()( \
@@ -55,23 +55,35 @@
 }
 
 template <typename T, typename Context>
-void MeanKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const std::vector<int64_t>& dims,
-                bool keep_dim,
-                bool reduce_all,
-                DenseTensor* out) {
+void MeanRawKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const std::vector<int64_t>& dims,
+                   bool keep_dim,
+                   bool reduce_all,
+                   DenseTensor* out) {
   auto out_dtype = x.dtype();
   pten::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
 
 template <typename T, typename Context>
-void DivideKernel(const Context& dev_ctx,
+void SumRawKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  const DenseTensor& y,
-                  int axis,
+                  const std::vector<int64_t>& dims,
+                  bool keep_dim,
+                  bool reduce_all,
+                  DataType out_dtype,
                   DenseTensor* out) {
+  pten::Reduce(
+      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
+}
+
+template <typename T, typename Context>
+void DivideRawKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& y,
+                     int axis,
+                     DenseTensor* out) {
   // allocate memory for out
   out->mutable_data<T>();
   if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
@@ -90,18 +102,6 @@ void DivideKernel(const Context& dev_ctx,
   }
 }
 
-template <typename T, typename Context>
-void SumKernel(const Context& dev_ctx,
-               const DenseTensor& x,
-               const std::vector<int64_t>& dims,
-               bool keep_dim,
-               bool reduce_all,
-               DataType out_dtype,
-               DenseTensor* out) {
-  pten::Reduce(
-      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
-}
-
 // Create the definition of Add
 DEFINE_CPU_ELEMENTWISE_OP(Add)
 
@@ -118,42 +118,40 @@ using complex128 = ::paddle::platform::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::paddle::platform::bfloat16;
 
-PT_REGISTER_KERNEL(
-    mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
-PT_REGISTER_KERNEL(add,
+PT_REGISTER_KERNEL(add_raw,
                    CPU,
                    ALL_LAYOUT,
-                   pten::AddKernel,
+                   pten::AddRawKernel,
                    float,
                    double,
                    int,
                    int64_t,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(subtract,
+PT_REGISTER_KERNEL(subtract_raw,
                    CPU,
                    ALL_LAYOUT,
-                   pten::SubtractKernel,
+                   pten::SubtractRawKernel,
                    float,
                    double,
                    int,
                    int64_t,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(divide,
+PT_REGISTER_KERNEL(divide_raw,
                    CPU,
                    ALL_LAYOUT,
-                   pten::DivideKernel,
+                   pten::DivideRawKernel,
                    float,
                    double,
                    int,
                    int64_t,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(multiply,
+PT_REGISTER_KERNEL(multiply_raw,
                    CPU,
                    ALL_LAYOUT,
-                   pten::MultiplyKernel,
+                   pten::MultiplyRawKernel,
                    float,
                    double,
                    int,
@@ -161,10 +159,10 @@ PT_REGISTER_KERNEL(multiply,
                    bool,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(sum,
+PT_REGISTER_KERNEL(sum_raw,
                    CPU,
                    ALL_LAYOUT,
-                   pten::SumKernel,
+                   pten::SumRawKernel,
                    bool,
                    float,
                    double,
@@ -175,3 +173,5 @@ PT_REGISTER_KERNEL(sum,
                    complex128) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
+PT_REGISTER_KERNEL(
+    mean_raw, CPU, ALL_LAYOUT, pten::MeanRawKernel, float, double, bool) {}
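
Since DEFINE_CPU_ELEMENTWISE_OP is the only place the new *RawKernel symbols are minted on CPU, one hand-expanded instance helps. The following is an approximation of what DEFINE_CPU_ELEMENTWISE_OP(Add) now generates, not Paddle code: types are stand-ins and only the same-dims fast path is shown (the real expansion is templated over T/Context and falls back to a broadcast path):

    #include <cassert>
    #include <vector>

    struct CPUContext {};
    struct DenseTensor {
      std::vector<float> data;
      size_t dims() const { return data.size(); }
    };

    // The generated symbol is AddRawKernel (previously AddKernel), and it
    // still carries the fluid-style axis parameter.
    void AddRawKernel(const CPUContext& dev_ctx, const DenseTensor& x,
                      const DenseTensor& y, int axis, DenseTensor* out) {
      (void)dev_ctx;
      (void)axis;                    // -1 means trailing-dimension broadcast
      assert(x.dims() == y.dims());  // same-dims fast path only
      out->data.resize(x.data.size());
      for (size_t i = 0; i < x.data.size(); ++i)
        out->data[i] = x.data[i] + y.data[i];
    }

    int main() {
      DenseTensor x{{1, 2, 3}}, y{{10, 20, 30}}, out;
      AddRawKernel(CPUContext{}, x, y, -1, &out);
      assert(out.data[2] == 33);
      return 0;
    }
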
diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu
index 1fd085ab5fe409..e2a09aaaa896b0 100644
--- a/paddle/pten/kernels/gpu/math_kernel.cu
+++ b/paddle/pten/kernels/gpu/math_kernel.cu
@@ -37,11 +37,11 @@ namespace pten {
 
 #define DEFINE_CUDA_ELEMENTWISE_OP(name)                    \
   template <typename T, typename Context>                   \
-  void name##Kernel(const Context& dev_ctx,                 \
-                    const DenseTensor& x,                   \
-                    const DenseTensor& y,                   \
-                    int axis,                               \
-                    DenseTensor* out) {                     \
+  void name##RawKernel(const Context& dev_ctx,              \
+                       const DenseTensor& x,                \
+                       const DenseTensor& y,                \
+                       int axis,                            \
+                       DenseTensor* out) {                  \
     std::vector<const DenseTensor*> inputs;                 \
     std::vector<DenseTensor*> outputs;                      \
     inputs.emplace_back(&x);                                \
@@ -72,17 +72,29 @@ struct DivideFunctor {
  */
 
 template <typename T, typename Context>
-void MeanKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const std::vector<int64_t>& dims,
-                bool keep_dim,
-                bool reduce_all,
-                DenseTensor* out) {
+void MeanRawKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const std::vector<int64_t>& dims,
+                   bool keep_dim,
+                   bool reduce_all,
+                   DenseTensor* out) {
   auto out_dtype = x.dtype();
   pten::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
 
+template <typename T, typename Context>
+void SumRawKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const std::vector<int64_t>& dims,
+                  bool keep_dim,
+                  bool reduce_all,
+                  DataType out_dtype,
+                  DenseTensor* out) {
+  pten::Reduce(
+      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
+}
+
 // Create the definition of Add
 DEFINE_CUDA_ELEMENTWISE_OP(Add)
 // Create the definition of Subtract
@@ -92,30 +104,16 @@ DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
 // Create the definition of Divide
 DEFINE_CUDA_ELEMENTWISE_OP(Divide)
 
-template <typename T, typename Context>
-void SumKernel(const Context& dev_ctx,
-               const DenseTensor& x,
-               const std::vector<int64_t>& dims,
-               bool keep_dim,
-               bool reduce_all,
-               DataType out_dtype,
-               DenseTensor* out) {
-  pten::Reduce(
-      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
-}
-
 }  // namespace pten
 
 using float16 = paddle::platform::float16;
 using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
-PT_REGISTER_KERNEL(
-    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
-PT_REGISTER_KERNEL(add,
+PT_REGISTER_KERNEL(add_raw,
                    GPU,
                    ALL_LAYOUT,
-                   pten::AddKernel,
+                   pten::AddRawKernel,
                    float,
                    double,
                    int,
@@ -123,10 +121,10 @@ PT_REGISTER_KERNEL(add,
                    float16,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(subtract,
+PT_REGISTER_KERNEL(subtract_raw,
                    GPU,
                    ALL_LAYOUT,
-                   pten::SubtractKernel,
+                   pten::SubtractRawKernel,
                    float,
                    double,
                    int,
@@ -134,10 +132,10 @@ PT_REGISTER_KERNEL(subtract,
                    float16,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(divide,
+PT_REGISTER_KERNEL(divide_raw,
                    GPU,
                    ALL_LAYOUT,
-                   pten::DivideKernel,
+                   pten::DivideRawKernel,
                    float,
                    double,
                    int,
@@ -145,10 +143,10 @@ PT_REGISTER_KERNEL(divide,
                    float16,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(multiply,
+PT_REGISTER_KERNEL(multiply_raw,
                    GPU,
                    ALL_LAYOUT,
-                   pten::MultiplyKernel,
+                   pten::MultiplyRawKernel,
                    float,
                    double,
                    int,
@@ -157,10 +155,10 @@ PT_REGISTER_KERNEL(multiply,
                    float16,
                    complex64,
                    complex128) {}
-PT_REGISTER_KERNEL(sum,
+PT_REGISTER_KERNEL(sum_raw,
                    GPU,
                    ALL_LAYOUT,
-                   pten::SumKernel,
+                   pten::SumRawKernel,
                    bool,
                    float,
                    double,
@@ -171,3 +169,12 @@ PT_REGISTER_KERNEL(sum,
                    complex128) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
+
+PT_REGISTER_KERNEL(mean_raw,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MeanRawKernel,
+                   float,
+                   double,
+                   bool,
+                   float16) {}
diff --git a/paddle/pten/kernels/math_kernel.cc b/paddle/pten/kernels/math_kernel.cc
new file mode 100644
index 00000000000000..620ad3f0b92b8d
--- /dev/null
+++ b/paddle/pten/kernels/math_kernel.cc
@@ -0,0 +1,204 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/math_kernel.h"
+
+#include "paddle/pten/backends/all_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                DenseTensor* out) {
+  bool reduce_all = false;
+  MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
+}
+
+template <typename T, typename Context>
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               DenseTensor* out) {
+  bool reduce_all = false;
+  DataType out_dtype = DataType::UNDEFINED;
+  SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
+}
+
+template <typename T, typename Context>
+void AddKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const DenseTensor& y,
+               DenseTensor* out) {
+  int axis = -1;
+  AddRawKernel<T>(dev_ctx, x, y, axis, out);
+}
+
+template <typename T, typename Context>
+void SubtractKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    DenseTensor* out) {
+  int axis = -1;
+  SubtractRawKernel<T>(dev_ctx, x, y, axis, out);
+}
+
+template <typename T, typename Context>
+void DivideKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* out) {
+  int axis = -1;
+  DivideRawKernel<T>(dev_ctx, x, y, axis, out);
+}
+
+template <typename T, typename Context>
+void MultiplyKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    DenseTensor* out) {
+  int axis = -1;
+  MultiplyRawKernel<T>(dev_ctx, x, y, axis, out);
+}
+
+}  // namespace pten
+
+using float16 = paddle::platform::float16;
+using complex64 = ::paddle::platform::complex<float>;
+using complex128 = ::paddle::platform::complex<double>;
+
+PT_REGISTER_KERNEL(
+    mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
+PT_REGISTER_KERNEL(
+    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
+PT_REGISTER_KERNEL(sum,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}
+PT_REGISTER_KERNEL(sum,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}
+
+PT_REGISTER_KERNEL(add,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::AddKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(subtract,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SubtractKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(divide,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DivideKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(multiply,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MultiplyKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   complex64,
+                   complex128) {}
+
+PT_REGISTER_KERNEL(add,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::AddKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(subtract,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SubtractKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(divide,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::DivideKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(multiply,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MultiplyKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   float16,
+                   complex64,
+                   complex128) {}
diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h
index 65c0f84e696dea..417e43d031b818 100644
--- a/paddle/pten/kernels/math_kernel.h
+++ b/paddle/pten/kernels/math_kernel.h
@@ -22,104 +22,126 @@ limitations under the License. */
 
 namespace pten {
 
+template <typename T, typename Context>
+void MeanRawKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const std::vector<int64_t>& dims,
+                   bool keep_dim,
+                   bool reduce_all,
+                   DenseTensor* out);
+
 template <typename T, typename Context>
 void MeanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
-                bool reduce_all,
                 DenseTensor* out);
 
+template <typename T, typename Context>
+void SumRawKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const std::vector<int64_t>& dims,
+                  bool keep_dim,
+                  bool reduce_all,
+                  DataType out_dtype,
+                  DenseTensor* out);
+
+template <typename T, typename Context>
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               DenseTensor* out);
+
+template <typename T, typename Context>
+void AddRawKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  int axis,
+                  DenseTensor* out);
+
 template <typename T, typename Context>
 void AddKernel(const Context& dev_ctx,
                const DenseTensor& x,
               const DenseTensor& y,
-               int axis,
               DenseTensor* out);
 
+template <typename T, typename Context>
+void SubtractRawKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       int axis,
+                       DenseTensor* out);
+
 template <typename T, typename Context>
 void SubtractKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
-                    int axis,
                     DenseTensor* out);
 
+template <typename T, typename Context>
+void DivideRawKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& y,
+                     int axis,
+                     DenseTensor* out);
+
 template <typename T, typename Context>
 void DivideKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
-                  int axis,
                   DenseTensor* out);
 
+template <typename T, typename Context>
+void MultiplyRawKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       int axis,
+                       DenseTensor* out);
+
 template <typename T, typename Context>
 void MultiplyKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
-                    int axis,
                     DenseTensor* out);
 
-template <typename T, typename Context>
-void SumKernel(const Context& dev_ctx,
-               const DenseTensor& x,
-               const std::vector<int64_t>& dims,
-               bool keep_dim,
-               bool reduce_all,
-               DataType out_dtype,
-               DenseTensor* out);
-
 template <typename T, typename Context>
 DenseTensor Add(const Context& dev_ctx,
                 const DenseTensor& x,
-                const DenseTensor& y,
-                int axis) {
-  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  AddKernel<T>(dev_ctx, x, y, axis, &dense_out);
+                const DenseTensor& y) {
+  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1);
+  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  AddKernel<T>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
 
 template <typename T, typename Context>
 DenseTensor Subtract(const Context& dev_ctx,
                      const DenseTensor& x,
-                     const DenseTensor& y,
-                     int axis) {
-  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  SubtractKernel<T>(dev_ctx, x, y, axis, &dense_out);
+                     const DenseTensor& y) {
+  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1);
+  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  SubtractKernel<T>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
 
 template <typename T, typename Context>
 DenseTensor Divide(const Context& dev_ctx,
                    const DenseTensor& x,
-                   const DenseTensor& y,
-                   int axis) {
-  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  DivideKernel<T>(dev_ctx, x, y, axis, &dense_out);
+                   const DenseTensor& y) {
+  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1);
+  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  DivideKernel<T>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
 
 template <typename T, typename Context>
 DenseTensor Multiply(const Context& dev_ctx,
                      const DenseTensor& x,
-                     const DenseTensor& y,
-                     int axis) {
-  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  MultiplyKernel<T>(dev_ctx, x, y, axis, &dense_out);
+                     const DenseTensor& y) {
+  auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1);
+  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  MultiplyKernel<T>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
 
@@ -130,8 +152,7 @@ DenseTensor Mean(const Context& dev_ctx,
                  bool keep_dim) {
   auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
   auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
-  bool reduce_all = false;
-  MeanKernel<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
+  MeanKernel<T>(dev_ctx, x, axis, keep_dim, &dense_out);
   return dense_out;
 }
 
@@ -144,12 +165,7 @@ DenseTensor Sum(const Context& dev_ctx,
   auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype);
   auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
 
-  // The real value of reduce_all will be get in kernel
-  // so use default value(false) is OK.
-  bool reduce_all = false;
-
-  SumKernel<T>(
-      dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
+  SumKernel<T>(dev_ctx, x, axis, keep_dim, &dense_out);
   return dense_out;
 }
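
After the header above, each pten op has three entry points: the *_raw kernel carrying the legacy fluid parameters, the slim kernel that pins those parameters to their defaults, and a functional wrapper that allocates the output and returns it. A self-contained toy of that layering, with stand-in types (pten's real versions are templated over T and Context and infer the output meta):

    #include <cassert>
    #include <vector>

    struct Context {};
    struct DenseTensor { std::vector<float> data; };

    // Layer 1 - raw kernel: keeps the fluid-era signature (explicit axis).
    void AddRawKernel(const Context&, const DenseTensor& x,
                      const DenseTensor& y, int /*axis*/, DenseTensor* out) {
      out->data.resize(x.data.size());
      for (size_t i = 0; i < x.data.size(); ++i)
        out->data[i] = x.data[i] + y.data[i];
    }

    // Layer 2 - new-style kernel: the C++ API surface, axis pinned to -1.
    void AddKernel(const Context& ctx, const DenseTensor& x,
                   const DenseTensor& y, DenseTensor* out) {
      AddRawKernel(ctx, x, y, /*axis=*/-1, out);
    }

    // Layer 3 - functional wrapper: allocates the output, returns by value.
    DenseTensor Add(const Context& ctx, const DenseTensor& x,
                    const DenseTensor& y) {
      DenseTensor out;
      AddKernel(ctx, x, y, &out);
      return out;
    }

    int main() {
      Context ctx;
      DenseTensor x{{1, 2}}, y{{3, 4}};
      DenseTensor out = Add(ctx, x, y);  // the dev-API tests below exercise this
      assert(out.data[0] == 4 && out.data[1] == 6);
      return 0;
    }
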
diff --git a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc
index 40998a8d57caa2..78db5f7de47aba 100644
--- a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc
+++ b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc
@@ -53,7 +53,6 @@ TEST(DEV_API, add) {
   for (size_t i = 0; i < 10; ++i) {
     dense_y_data[i] = i * 2.0;
   }
-  int axis = 1;
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
 
   // 2. test API
   auto dense_out = pten::Add<float>(
       *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
       dense_x,
-      dense_y,
-      axis);
+      dense_y);
 
   // 3. check result
   ASSERT_EQ(dense_out.dims().size(), 2);
@@ -106,7 +104,6 @@ TEST(DEV_API, subtract) {
   for (size_t i = 0; i < 10; ++i) {
     dense_y_data[i] = i * 2.0;
   }
-  int axis = 1;
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
 
   // 2. test API
   auto dense_out = pten::Subtract<float>(
       *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
       dense_x,
-      dense_y,
-      axis);
+      dense_y);
 
   // 3. check result
   ASSERT_EQ(dense_out.dims().size(), 2);
@@ -159,7 +155,6 @@ TEST(DEV_API, divide) {
   for (size_t i = 0; i < 10; ++i) {
     dense_y_data[i] = i * 2.0 + 1;
   }
-  int axis = 1;
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
 
   // 2. test API
   auto dense_out = pten::Divide<float>(
       *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
       dense_x,
-      dense_y,
-      axis);
+      dense_y);
 
   // 3. check result
   ASSERT_EQ(dense_out.dims().size(), 2);
@@ -212,7 +206,6 @@ TEST(DEV_API, multiply) {
   for (size_t i = 0; i < 10; ++i) {
     dense_y_data[i] = i * 2.0;
   }
-  int axis = 1;
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
 
   // 2. test API
   auto dense_out = pten::Multiply<float>(
       *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
       dense_x,
-      dense_y,
-      axis);
+      dense_y);
 
   // 3. check result
   ASSERT_EQ(dense_out.dims().size(), 2);
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 562a726aa29f27..6fae91c0417d91 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -6,7 +6,6 @@
     param : [x, y, -1]
   kernel :
     func : add
-    param : [x, y, -1]
 
 - api : cast
   args : (const Tensor& x, DataType out_dtype)
@@ -34,7 +33,6 @@
     param : [x, y, -1]
   kernel :
     func : divide
-    param : [x, y, -1]
 
 - api : dot
   args : (const Tensor& x, const Tensor& y)
@@ -120,7 +118,6 @@
     param: [x, axis, keep_dim]
   kernel :
     func : mean
-    param : [x, axis, keep_dim, false]
 
 - api : multiply
   args : (const Tensor& x, const Tensor& y)
@@ -130,7 +127,6 @@
     param : [x, y, -1]
   kernel :
     func : multiply
-    param : [x, y, -1]
 
 - api : ones_like
   args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
@@ -162,7 +158,6 @@
     param : [x, y, -1]
   kernel :
     func : subtract
-    param : [x, y, -1]
 
 - api : sum
   args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
@@ -172,7 +167,7 @@
     param: [x, axis, keep_dim, dtype]
   kernel :
     func : sum
-    param : [x, axis, keep_dim, false, DataType::UNDEFINED]
+    param : [x, axis, keep_dim]
    data_type : x
 
 - api : zeros_like
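
For orientation: api.yaml is the input to Paddle's C++ API code generator, so trimming `param` here changes what the generated functions forward to the kernels. Roughly — and this is an illustrative sketch with hypothetical stand-in types, not the actual generated source — the `sum` entry now corresponds to:

    #include <cstdint>
    #include <vector>

    // Stand-ins for the experimental API types; illustrative only.
    enum class DataType { UNDEFINED, FLOAT32 };
    struct Tensor {};

    // Signature described by the yaml entry:
    //   args : (const Tensor& x, const std::vector<int64_t>& axis={},
    //           DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
    // After this patch the generated body forwards only [x, axis, keep_dim];
    // reduce_all and the dtype default are resolved inside the kernel layer
    // instead of being passed from the API.
    Tensor sum(const Tensor& x, const std::vector<std::int64_t>& axis = {},
               DataType dtype = DataType::UNDEFINED, bool keep_dim = false) {
      (void)x; (void)axis; (void)dtype; (void)keep_dim;
      Tensor out;
      // kernel dispatch would happen here:
      // SumKernel(dev_ctx, x, axis, keep_dim, &out);
      return out;
    }

    int main() {
      Tensor x;
      Tensor y = sum(x, {0});
      (void)y;
      return 0;
    }
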
From 84e4328fe6567aacaf818903061b7dd62fc18178 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Mon, 17 Jan 2022 13:12:09 +0000
Subject: [PATCH 02/10] fix compile bugs

---
 paddle/pten/kernels/math_kernel.cc | 35 ++++++++++++++++--------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/paddle/pten/kernels/math_kernel.cc b/paddle/pten/kernels/math_kernel.cc
index 620ad3f0b92b8d..fc28380eeb805b 100644
--- a/paddle/pten/kernels/math_kernel.cc
+++ b/paddle/pten/kernels/math_kernel.cc
@@ -84,8 +84,7 @@ using complex128 = ::paddle::platform::complex<double>;
 
 PT_REGISTER_KERNEL(
     mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
-PT_REGISTER_KERNEL(
-    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
+
 PT_REGISTER_KERNEL(sum,
                    CPU,
                    ALL_LAYOUT,
                    pten::SumKernel,
                    bool,
                    float,
                    double,
                    paddle::platform::float16,
                    int,
                    int64_t,
                    complex64,
                    complex128) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
-PT_REGISTER_KERNEL(sum,
-                   GPU,
-                   ALL_LAYOUT,
-                   pten::SumKernel,
-                   bool,
-                   float,
-                   double,
-                   float16,
-                   int,
-                   int64_t,
-                   complex64,
-                   complex128) {
-  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
-}
 
 PT_REGISTER_KERNEL(add,
                    CPU,
@@ -157,6 +142,23 @@ PT_REGISTER_KERNEL(multiply,
                    complex64,
                    complex128) {}
 
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_REGISTER_KERNEL(
+    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
+PT_REGISTER_KERNEL(sum,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}
 PT_REGISTER_KERNEL(add,
                    GPU,
@@ -202,3 +204,4 @@ PT_REGISTER_KERNEL(multiply,
                    float16,
                    complex64,
                    complex128) {}
+#endif

From 7c0a653366e3d34781b14e9365d6fb447d183369 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Tue, 18 Jan 2022 02:46:55 +0000
Subject: [PATCH 03/10] fix kunlun compile bugs

---
 paddle/pten/kernels/math_kernel.cc | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/paddle/pten/kernels/math_kernel.cc b/paddle/pten/kernels/math_kernel.cc
index fc28380eeb805b..c86adf158e678f 100644
--- a/paddle/pten/kernels/math_kernel.cc
+++ b/paddle/pten/kernels/math_kernel.cc
@@ -78,7 +78,6 @@ void MultiplyKernel(const Context& dev_ctx,
 
 }  // namespace pten
 
-using float16 = paddle::platform::float16;
 using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
@@ -142,8 +141,14 @@ PT_REGISTER_KERNEL(multiply,
                    complex128) {}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_REGISTER_KERNEL(
-    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
+PT_REGISTER_KERNEL(mean,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MeanKernel,
+                   float,
+                   double,
+                   bool,
+                   paddle::platform::float16) {}
 PT_REGISTER_KERNEL(sum,
                    GPU,
                    ALL_LAYOUT,
@@ -152,7 +157,7 @@ PT_REGISTER_KERNEL(sum,
                    bool,
                    float,
                    double,
-                   float16,
+                   paddle::platform::float16,
                    int,
                    int64_t,
                    complex64,
@@ -167,7 +172,7 @@ PT_REGISTER_KERNEL(add,
                    double,
                    int,
                    int64_t,
-                   float16,
+                   paddle::platform::float16,
                    complex64,
                    complex128) {}
 PT_REGISTER_KERNEL(subtract,
@@ -178,7 +183,7 @@ PT_REGISTER_KERNEL(subtract,
                    double,
                    int,
                    int64_t,
-                   float16,
+                   paddle::platform::float16,
                    complex64,
                    complex128) {}
 PT_REGISTER_KERNEL(divide,
@@ -189,7 +194,7 @@ PT_REGISTER_KERNEL(divide,
                    double,
                    int,
                    int64_t,
-                   float16,
+                   paddle::platform::float16,
                    complex64,
                    complex128) {}
 PT_REGISTER_KERNEL(multiply,
@@ -201,7 +206,7 @@ PT_REGISTER_KERNEL(multiply,
                    int,
                    int64_t,
                    bool,
-                   float16,
+                   paddle::platform::float16,
                    complex64,
                    complex128) {}
 #endif

From deee972c734bc6d707772b7d77d53b2a9a3addb9 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Tue, 18 Jan 2022 12:57:49 +0000
Subject: [PATCH 04/10] perfect cmake

---
 cmake/pten_kernel.cmake | 24 +++++------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake
index 9d5d4957b4ef42..8f4d75a88dbcb7 100644
--- a/cmake/pten_kernel.cmake
+++ b/cmake/pten_kernel.cmake
@@ -19,6 +19,7 @@ function(kernel_declare TARGET_LIST)
     # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
     # NOTE(chenweihang): now we don't recommend to use digit in kernel name
     string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
+    message("regex kernel decaler: ${first_registry}")
     if (NOT first_registry STREQUAL "")
       # parse the first kernel name
       string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
@@ -104,35 +105,20 @@ function(kernel_library TARGET)
   list(LENGTH xpu_srcs xpu_srcs_len)
 
   if(${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
-     ${xpu_srcs_len} GREATER 0)
+     ${xpu_srcs_len} 
GREATER 0 OR ${common_srcs_len} GREATER 0) if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() - if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() endif() - else() - if (${common_srcs_len} EQUAL 0) - message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") - else() - # If the kernel has a device independent public implementation, - # we will use this implementation and will not adopt the implementation - # under specific devices - if (WITH_GPU) - nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - else() - cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() - endif() endif() if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR From f25b3c4bcfcad55b151e4457e8b9cd75e61fea42 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 20 Jan 2022 02:26:45 +0000 Subject: [PATCH 05/10] fix compile bugs when run ci-inference --- cmake/pten_kernel.cmake | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index 8f4d75a88dbcb7..b48905902687ca 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -19,7 +19,6 @@ function(kernel_declare TARGET_LIST) # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL # NOTE(chenweihang): now we don't recommend to use digit in kernel name string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") - message("regex kernel decaler: ${first_registry}") if (NOT first_registry STREQUAL "") # parse the first kernel name string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}") @@ -108,15 +107,15 @@ function(kernel_library TARGET) ${xpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) if (WITH_GPU) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + nv_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() elseif (WITH_ROCM) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + hip_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${common_srcs} DEPS ${kernel_library_DEPS} 
${kernel_deps}) + cc_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() endif() endif() From 59864105b333b1f28e2f124ec7ba5005123eb581 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 20 Jan 2022 08:36:39 +0000 Subject: [PATCH 06/10] fix compile bugs --- cmake/pten_kernel.cmake | 50 ++++++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index b48905902687ca..c2928376a02f8f 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -103,21 +103,55 @@ function(kernel_library TARGET) list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len) - if(${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR - ${xpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) + # Build Target according different src organization + if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR + ${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0) + # If the common_srcs depends on specific device srcs, build target using this rule. if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) endif() elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) endif() else() - if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${common_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${common_srcs} ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) endif() endif() + elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + if (WITH_GPU) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + elseif (WITH_ROCM) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + else() + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + endif() + else() + if (${common_srcs_len} EQUAL 0) + message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") + else() + # If the kernel has a device independent public implementation, + # we will use this implementation and will not adopt the implementation + # under specific devices + if (WITH_GPU) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + elseif 
(WITH_ROCM)
+        hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      else()
+        cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      endif()
+    endif()
   endif()
 
   if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR

From f1f887442d9d520e92fd963d2e79236824cad782 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 21 Jan 2022 02:55:56 +0000
Subject: [PATCH 07/10] add non-raw kernel for fluid op

---
 .../fluid/operators/elementwise/elementwise_op.h | 15 +++++++++++++++
 paddle/fluid/operators/reduce_ops/reduce_op.h    |  9 +++++++++
 2 files changed, 24 insertions(+)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 9e2ffbbb21b90b..aaf33ca6744886 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -140,26 +140,41 @@ class ElementwiseOp : public framework::OperatorWithKernel {
 
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext &ctx) const override {
+    int axis = ctx.Attr<int>("axis");
     if (Type() == "elementwise_add") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (axis == -1) {
+          return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"});
+        }
         return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_sub") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (axis == -1) {
+          return framework::KernelSignature("subtract", {"X", "Y"}, {},
+                                            {"Out"});
+        }
         return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_div") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (axis == -1) {
+          return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
+        }
         return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
     if (Type() == "elementwise_mul") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (axis == -1) {
+          return framework::KernelSignature("multiply", {"X", "Y"}, {},
+                                            {"Out"});
+        }
         return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
                                           {"Out"});
       }
     }
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index ffec9f92f49bcd..7200963a7097fe 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -551,14 +551,23 @@ class ReduceOp : public framework::OperatorWithKernel {
 
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext& ctx) const override {
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
     if (Type() == "reduce_sum") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (!reduce_all) {
+          return framework::KernelSignature(
+              "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+        }
         return framework::KernelSignature(
             "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
            {"Out"});
       }
     }
     if (Type() == "reduce_mean") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        if (!reduce_all) {
+          return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
+                                            {"Out"});
+        }
         return framework::KernelSignature(
             "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
       }
     }
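
The rule patch 07 adds, stated plainly: an op instance whose attributes still carry non-default values keeps the *_raw kernel; only the default case is routed to the slim kernel introduced for the C++ API. A standalone sketch of that decision for the add case (KernelSignature here is a stand-in struct, not the fluid type):

    #include <cassert>
    #include <string>

    // Stand-in for framework::KernelSignature; illustrative only.
    struct KernelSignature {
      std::string kernel_name;
    };

    // Mirrors the new dispatch in ElementwiseOp::GetExpectedPtenKernelArgs:
    // the default axis (-1) maps to the slim kernel with axis pinned
    // internally, while any explicit axis keeps the raw kernel.
    KernelSignature PickAddKernel(int axis) {
      if (axis == -1) {
        return {"add"};      // new-style kernel
      }
      return {"add_raw"};    // legacy path, axis forwarded
    }

    int main() {
      assert(PickAddKernel(-1).kernel_name == "add");
      assert(PickAddKernel(0).kernel_name == "add_raw");
      return 0;
    }
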
--git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h
index b112ebd718e442..95379baaf35043 100644
--- a/paddle/pten/kernels/math_kernel.h
+++ b/paddle/pten/kernels/math_kernel.h
@@ -166,7 +166,7 @@ DenseTensor Sum(const Context& dev_ctx,
   auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype);
   auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
 
-  SumKernel<T>(dev_ctx, x, axis, keep_dim, &dense_out);
+  SumKernel<T>(dev_ctx, x, axis, keep_dim, dtype, &dense_out);
 
   return dense_out;
 }

From 794e60f3ad7be010be5ac4dfcc2361987f12d14c Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 21 Jan 2022 05:52:21 +0000
Subject: [PATCH 09/10] fix compile bugs

---
 paddle/pten/api/include/kernel_signature.h | 1 +
 paddle/pten/core/kernel_alias_name.h       | 6 ++++++
 2 files changed, 7 insertions(+)

diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h
index 1ad2f5da5d66ad..76c03f93a3d93b 100644
--- a/paddle/pten/api/include/kernel_signature.h
+++ b/paddle/pten/api/include/kernel_signature.h
@@ -98,6 +98,7 @@ using sum_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const std::vector<int64_t>&,
                             bool,
+                            DataType,
                             DenseTensor*);
 
 using subtract_kernel = void (*)(const DeviceContext&,
diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h
index b2c568d6ad4a7e..8e089970f9139e 100644
--- a/paddle/pten/core/kernel_alias_name.h
+++ b/paddle/pten/core/kernel_alias_name.h
@@ -20,6 +20,10 @@ namespace pten {
 // the key is kernel_name in fluid, the value is the kernel_name in pten
 // the key is sorted by key's alphabet
 const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
+    {"elementwise_add", "add_raw"},
+    {"elementwise_div", "divide_raw"},
+    {"elementwise_mul", "multiply_raw"},
+    {"elementwise_sub", "subtract_raw"},
     {"fill_any_like", "full_like"},
     {"fill_constant", "full"},
     {"flatten_contiguous_range", "flatten"},
@@ -28,6 +32,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
     {"matmul_v2_grad", "matmul_grad"},
     {"matmul_v2_grad_grad", "matmul_double_grad"},
     {"matmul_v2_triple_grad", "matmul_triple_grad"},
+    {"reduce_mean", "mean_raw"},
+    {"reduce_sum", "sum_raw"},
     {"reshape2", "reshape"},
     {"reshape2_grad", "reshape_grad"},
     {"reshape2_grad_grad", "reshape_double_grad"},

From 90202188ac4108d0f0168d5618bbedf70b9ba08d Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 21 Jan 2022 06:26:27 +0000
Subject: [PATCH 10/10] fix unit test bug

---
 paddle/fluid/operators/reduce_ops/reduce_op.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 7200963a7097fe..2e5bd7a42b1d1a 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -551,7 +551,7 @@ class ReduceOp : public framework::OperatorWithKernel {
 
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext& ctx) const override {
-    bool reduce_all = ctx.Attr<bool>("reduce_all");
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
     if (Type() == "reduce_sum") {
       if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
         if (!reduce_all) {