From 3f668d14839358d362002a54db3eabd89dce494b Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Tue, 15 Mar 2022 14:51:18 +0000 Subject: [PATCH 1/6] move clip op to phi --- paddle/fluid/operators/clip_op.cc | 27 +-- paddle/fluid/operators/clip_op.cu | 32 --- paddle/fluid/operators/clip_op.h | 196 ------------------ paddle/fluid/operators/clip_op_npu.cc | 2 +- paddle/fluid/operators/clip_op_xpu.cc | 1 - paddle/fluid/operators/fake_quantize_op.cc | 14 +- .../fluid/operators/hierarchical_sigmoid_op.h | 4 +- .../operators/math/selected_rows_functor.cc | 177 ++++++++++++---- .../operators/math/selected_rows_functor.cu | 196 ++++++++++++++++-- paddle/phi/kernels/CMakeLists.txt | 2 +- paddle/phi/kernels/clip_grad_kernel.h | 31 +++ paddle/phi/kernels/clip_kernel.h | 31 +++ paddle/phi/kernels/cpu/clip_grad_kernel.cc | 24 +++ paddle/phi/kernels/cpu/clip_kernel.cc | 18 ++ paddle/phi/kernels/gpu/clip_grad_kernel.cu | 24 +++ paddle/phi/kernels/gpu/clip_kernel.cu | 26 +++ .../phi/kernels/impl/clip_grad_kernel_impl.h | 74 +++++++ paddle/phi/kernels/impl/clip_kernel_impl.h | 79 +++++++ .../phi/kernels/selected_rows/clip_kernel.cc | 24 +++ .../phi/kernels/selected_rows/clip_kernel.cu | 26 +++ .../phi/kernels/selected_rows/clip_kernel.h | 61 ++++++ paddle/phi/ops/compat/clip_sig.cc | 47 +++++ 22 files changed, 797 insertions(+), 319 deletions(-) delete mode 100644 paddle/fluid/operators/clip_op.cu delete mode 100644 paddle/fluid/operators/clip_op.h create mode 100644 paddle/phi/kernels/clip_grad_kernel.h create mode 100644 paddle/phi/kernels/clip_kernel.h create mode 100644 paddle/phi/kernels/cpu/clip_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/clip_kernel.cc create mode 100644 paddle/phi/kernels/gpu/clip_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/clip_kernel.cu create mode 100644 paddle/phi/kernels/impl/clip_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/clip_kernel_impl.h create mode 100644 paddle/phi/kernels/selected_rows/clip_kernel.cc create mode 100644 paddle/phi/kernels/selected_rows/clip_kernel.cu create mode 100644 paddle/phi/kernels/selected_rows/clip_kernel.h create mode 100644 paddle/phi/ops/compat/clip_sig.cc diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index 436d1edcedf1e0..3f8755a6203b51 100644 --- a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/clip_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -23,15 +25,6 @@ namespace operators { class ClipOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip"); - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = @@ -176,23 +169,15 @@ class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(clip, ClipInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker, ops::ClipGradOpMaker, ops::ClipGradOpMaker, - ops::ClipInplaceInferer); + ops::ClipInplaceInferer, ClipInferShapeFunctor); REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer, ops::ClipDoubleGradOpMaker, ops::ClipDoubleGradOpMaker); -REGISTER_OP_CPU_KERNEL( - clip, ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel); -REGISTER_OP_CPU_KERNEL( - clip_grad, ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel); REGISTER_OP_VERSION(clip) .AddCheckpoint( diff --git a/paddle/fluid/operators/clip_op.cu b/paddle/fluid/operators/clip_op.cu deleted file mode 100644 index 846354fcb81c5f..00000000000000 --- a/paddle/fluid/operators/clip_op.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/clip_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - clip, ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel); - -REGISTER_OP_CUDA_KERNEL( - clip_grad, ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel); diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h deleted file mode 100644 index 3b815cd1fa74a6..00000000000000 --- a/paddle/fluid/operators/clip_op.h +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/transform.h" -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#endif - -namespace paddle { -namespace operators { - -using framework::Tensor; -using platform::Transform; - -template -class ClipFunctor { - public: - explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} - HOSTDEVICE T operator()(const T x) const { - return x < min_ ? min_ : x > max_ ? max_ : x; - } - - private: - T min_; - T max_; -}; - -template -class ClipGradFunctor { - public: - explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} - HOSTDEVICE T operator()(const T x, const T y) const { - return (y > min_ && y < max_) ? x : static_cast(0); - } - - private: - T min_; - T max_; -}; - -template -class ClipKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto max = static_cast(context.Attr("max")); - Tensor max_cpu; - if (context.HasInput("Max")) { - auto* max_t = context.Input("Max"); - auto* max_data = max_t->data(); - if (platform::is_gpu_place(max_t->place())) { - paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(), - &max_cpu); - max_data = max_cpu.data(); - } - max = max_data[0]; - } - max = static_cast(max); - - auto min = static_cast(context.Attr("min")); - Tensor min_cpu; - if (context.HasInput("Min")) { - auto* min_t = context.Input("Min"); - auto* min_data = min_t->data(); - if (platform::is_gpu_place(min_t->place())) { - paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(), - &min_cpu); - min_data = min_cpu.data(); - } - min = min_data[0]; - } - - PADDLE_ENFORCE_LE(min, max, - platform::errors::InvalidArgument( - "max should be greater than or equal to min. 
" - "But received min = %f, max = %f", - static_cast(min), static_cast(max))); - - auto* x_var = context.InputVar("X"); - if (x_var->IsType()) { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - T* out_data = out->mutable_data(context.GetPlace()); - const T* x_data = x->data(); - int64_t numel = x->numel(); - if (platform::is_gpu_place(context.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - std::vector ins = {x}; - std::vector outs = {out}; - auto functor = ClipFunctor(min, max); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - context.template device_context(), ins, - &outs, functor); -#endif - } else { - Transform trans; - trans(context.template device_context(), x_data, - x_data + numel, out_data, ClipFunctor(min, max)); - } - } else if (x_var->IsType()) { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument( - "Inplace clip is not allowed " - "when x is SelectedRows")); - math::scatter::MergeAdd merge_func; - merge_func(context.template device_context(), *x, out); - auto* out_tensor = out->mutable_value(); - auto* out_data = out_tensor->data(); - int64_t numel = out_tensor->numel(); - Transform trans; - trans(context.template device_context(), out_data, - out_data + numel, out_data, ClipFunctor(min, max)); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "ClipOp only supports LoDTensor and SelectedRows.")); - } - } -}; - -template -class ClipGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto max = static_cast(context.Attr("max")); - Tensor max_cpu; - if (context.HasInput("Max")) { - auto* max_t = context.Input("Max"); - auto* max_data = max_t->data(); - if (platform::is_gpu_place(max_t->place())) { - paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(), - &max_cpu); - max_data = max_cpu.data(); - } - max = max_data[0]; - } - max = static_cast(max); - - auto min = static_cast(context.Attr("min")); - Tensor min_cpu; - if (context.HasInput("Min")) { - auto* min_t = context.Input("Min"); - auto* min_data = min_t->data(); - if (platform::is_gpu_place(min_t->place())) { - paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(), - &min_cpu); - min_data = min_cpu.data(); - } - min = min_data[0]; - } - min = static_cast(min); - - auto* d_out = - context.Input(framework::GradVarName("Out")); - auto* d_x = - context.Output(framework::GradVarName("X")); - if (d_x != nullptr) { - auto* x = context.Input("X"); -#if defined(__NVCC__) || defined(__HIPCC__) - std::vector ins = {d_out, x}; - std::vector outs = {d_x}; - auto functor = ClipGradFunctor(min, max); - d_x->mutable_data(context.GetPlace()); - LaunchSameDimsElementwiseCudaKernel( - context.template device_context(), ins, - &outs, functor); -#else - int64_t numel = d_out->numel(); - auto* d_x_data = d_x->mutable_data(context.GetPlace()); - const T* d_out_data = d_out->data(); - const T* x_data = x->data(); - Transform trans; - trans(context.template device_context(), d_out_data, - d_out_data + numel, x_data, d_x_data, ClipGradFunctor(min, max)); -#endif - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/clip_op_npu.cc b/paddle/fluid/operators/clip_op_npu.cc index 372ba707329bb3..32b17f94253a71 100644 --- a/paddle/fluid/operators/clip_op_npu.cc +++ b/paddle/fluid/operators/clip_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc index c53bb2d9e4d0cb..b11a6717c63bb6 100644 --- a/paddle/fluid/operators/clip_op_xpu.cc +++ b/paddle/fluid/operators/clip_op_xpu.cc @@ -14,7 +14,6 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 4544386718813c..ac72f23d46ea84 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -17,8 +17,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace paddle { namespace operators { @@ -91,7 +91,7 @@ struct ClipAndFakeQuantFunctor { T inv_s = inverse(s); platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); } @@ -109,7 +109,7 @@ struct ClipAndFakeQuantDequantFunctor { platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); @@ -144,7 +144,7 @@ struct ChannelClipAndFakeQuantFunctor { auto* start = in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); + phi::ClipFunctor(-s, s)); } for (int64_t i = 0; i < channel; i++) { T s = scale_data[i]; @@ -163,7 +163,7 @@ struct ChannelClipAndFakeQuantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); for (int k = 0; k < step_j; k++) { cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); } @@ -200,7 +200,7 @@ struct ChannelClipFakeQuantDequantFunctor { auto* start = in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); + phi::ClipFunctor(-s, s)); } for (int i = 0; i < channel; i++) { T s = scale_data[i]; @@ -220,7 +220,7 @@ struct ChannelClipFakeQuantDequantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); for (int k = 0; k < step_j; k++) { cur_out_data[k] = 
std::round(bin_cnt * inv_s * cur_out_data[k]) * s / static_cast(bin_cnt); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index f11b28cfefb071..34476be03f60d3 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -23,10 +23,10 @@ limitations under the License. */ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace paddle { namespace operators { @@ -108,7 +108,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { Transform trans; trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, - ClipFunctor(static_cast(-40.0), static_cast(40.0))); + phi::ClipFunctor(static_cast(-40.0), static_cast(40.0))); bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 5ac39953462b50..0ca2529f132a0b 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -279,6 +279,46 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + if (UNLIKELY(input1.rows().size() == 0)) { + LOG(WARNING) << "input selected rows is empty!"; + return; + } + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." + "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; @@ -286,6 +326,11 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; // This is a separated namespace for manipulate SelectedRows typed // data. Like merge duplicated rows, adding two SelectedRows etc. 
// @@ -294,30 +339,30 @@ template struct SelectedRowsAddToTensor +template typename std::enable_if::value>::type elementwise_add_to( - phi::funcs::BlasT* blas, size_t data_len, - const T* in, T* out) { + phi::funcs::BlasT* blas, size_t data_len, const T* in, + T* out) { blas->AXPY(data_len, T(1.f), in, out); } -template +template typename std::enable_if::value>::type elementwise_add_to( - phi::funcs::BlasT* blas, size_t data_len, - const T* in, T* out) { + phi::funcs::BlasT* blas, size_t data_len, const T* in, + T* out) { for (size_t i = 0; i < data_len; i++) { out[i] += in[i]; } } -template +template typename std::enable_if::value>::type add_sparse_inputs(const std::vector& inputs, const std::unordered_map& rows_to_id, - int64_t input_width, - const platform::CPUDeviceContext& context, T* out_data) { + int64_t input_width, const DeviceContext& context, + T* out_data) { #ifndef PADDLE_WITH_MKLDNN - auto blas = phi::funcs::GetBlas(context); + auto blas = phi::funcs::GetBlas(context); #endif for (auto* input : inputs) { if (input->rows().size() == 0) { @@ -336,22 +381,22 @@ add_sparse_inputs(const std::vector& inputs, #else for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id.at(input_rows[i]); - elementwise_add_to(&blas, static_cast(input_width), - &input_data[i * input_width], - &out_data[out_i * input_width]); + elementwise_add_to( + &blas, static_cast(input_width), &input_data[i * input_width], + &out_data[out_i * input_width]); } #endif } } -template +template typename std::enable_if::value>::type add_sparse_inputs(const std::vector& inputs, const std::unordered_map& rows_to_id, - int64_t input_width, - const platform::CPUDeviceContext& context, T* out_data) { + int64_t input_width, const DeviceContext& context, + T* out_data) { VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name(); - auto blas = phi::funcs::GetBlas(context); + auto blas = phi::funcs::GetBlas(context); for (auto* input : inputs) { if (input->rows().size() == 0) { continue; @@ -361,16 +406,16 @@ add_sparse_inputs(const std::vector& inputs, for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id.at(input_rows[i]); - elementwise_add_to(&blas, static_cast(input_width), - &input_data[i * input_width], - &out_data[out_i * input_width]); + elementwise_add_to( + &blas, static_cast(input_width), &input_data[i * input_width], + &out_data[out_i * input_width]); } } } -template -struct MergeAdd { - phi::SelectedRows operator()(const platform::CPUDeviceContext& context, +template +struct MergeAddImpl { + phi::SelectedRows operator()(const DeviceContext& context, const phi::SelectedRows& input, const bool sorted_result = false) { phi::SelectedRows out; @@ -378,15 +423,14 @@ struct MergeAdd { return out; } - void operator()(const platform::CPUDeviceContext& context, - const phi::SelectedRows& input, phi::SelectedRows* output, - const bool sorted_result = false) { + void operator()(const DeviceContext& context, const phi::SelectedRows& input, + phi::SelectedRows* output, const bool sorted_result = false) { std::vector inputs; inputs.push_back(&input); (*this)(context, inputs, output, sorted_result); } - void operator()(const platform::CPUDeviceContext& context, + void operator()(const DeviceContext& context, const std::vector& inputs, phi::SelectedRows* output, const bool sorted_result = false) { if (inputs.size() == 0) { @@ -461,7 +505,7 @@ struct MergeAdd { out.set_rows(merge_rows); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; 
constant_functor(context, out.mutable_value(), static_cast(0.f)); std::unordered_map rows_to_id; @@ -469,11 +513,75 @@ struct MergeAdd { rows_to_id[merge_rows[i]] = i; } - add_sparse_inputs(inputs, rows_to_id, input_width, context, out_data); + add_sparse_inputs(inputs, rows_to_id, input_width, + context, out_data); } } }; +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + phi::SelectedRows operator()(const platform::CPUDeviceContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, + sorted_result); + } + + void operator()(const platform::CPUDeviceContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, + sorted_result); + } + + void operator()(const platform::CPUDeviceContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, + sorted_result); + } +}; + +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + phi::SelectedRows operator()(const phi::CPUContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, sorted_result); + } + + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, sorted_result); + } + + void operator()(const phi::CPUContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, sorted_result); + } +}; + +#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype) \ + template struct MergeAddImpl; \ + template struct MergeAddImpl; \ + template struct MergeAdd; \ + template struct MergeAdd; + +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::bfloat16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex) +TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex) + #ifdef PADDLE_WITH_XPU template struct MergeAdd { @@ -714,17 +822,6 @@ struct MergeAverage { } }; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd>; -template struct MergeAdd>; -template struct MergeAdd; - #ifdef PADDLE_WITH_XPU template struct MergeAdd; #endif diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a4678550cf7bd0..542d4c9784352e 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -174,12 +174,77 @@ struct SelectedRowsAddTensor { } }; +template +struct SelectedRowsAddTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument( + "The two inputs height must be equal." 
+ "But recieved first input height = [%d], first input height = [%d]", + in1_height, in2_dims[0])); + PADDLE_ENFORCE_EQ( + in1_height, out_dims[0], + platform::errors::InvalidArgument( + "The input and output height must be equal." + "But recieved input height = [%d], output height = [%d]", + in1_height, out_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2.numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." + "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2.numel() / in1_height)); + PADDLE_ENFORCE_EQ( + in1_row_numel, output->numel() / in1_height, + platform::errors::InvalidArgument( + "The input and output width must be equal." + "But recieved input width = [%d], output width = [%d]", + in1_row_numel, output->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2.data(); + auto* out_data = output->data(); + + phi::funcs::SetConstant functor; + functor(context, output, static_cast(0)); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data, + in1_row_numel); + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen; + } +}; + template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; template struct SelectedRowsAdd; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAdd; +template struct SelectedRowsAddTensor; + template struct SelectedRowsAddTo { void operator()(const platform::CUDADeviceContext& context, @@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." 
+ "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2->data(); + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddToTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data, + in1_row_numel); + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -319,9 +426,9 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, } } -template -struct MergeAdd { - phi::SelectedRows operator()(const platform::CUDADeviceContext& context, +template +struct MergeAddImpl { + phi::SelectedRows operator()(const DeviceContext& context, const phi::SelectedRows& input, const bool sorted_result = false) { phi::SelectedRows out; @@ -329,9 +436,8 @@ struct MergeAdd { return out; } - void operator()(const platform::CUDADeviceContext& context, - const phi::SelectedRows& input, phi::SelectedRows* output, - const bool sorted_result = false) { + void operator()(const DeviceContext& context, const phi::SelectedRows& input, + phi::SelectedRows* output, const bool sorted_result = false) { framework::Vector input_rows(input.rows()); if (input_rows.size() == 0) { return; @@ -350,7 +456,7 @@ struct MergeAdd { phi::make_ddim({static_cast(merge_rows.size()), input_width}), context.GetPlace()); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); @@ -369,7 +475,7 @@ struct MergeAdd { mix_vector_out.CopyToCPU(); } - void operator()(const platform::CUDADeviceContext& context, + void operator()(const DeviceContext& context, const std::vector& inputs, phi::SelectedRows* output, const bool sorted_result = false) { if (inputs.size() == 0) { @@ -414,7 +520,7 @@ struct MergeAdd { phi::make_ddim({static_cast(merge_rows.size()), input_width}), context.GetPlace()); - phi::funcs::SetConstant constant_functor; + phi::funcs::SetConstant constant_functor; constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); @@ -441,15 +547,69 @@ struct MergeAdd { } }; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd>; -template struct MergeAdd>; +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. 
+ phi::SelectedRows operator()(const platform::CUDADeviceContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, + sorted_result); + } + + void operator()(const platform::CUDADeviceContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, + sorted_result); + } + + void operator()(const platform::CUDADeviceContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, + sorted_result); + } +}; + +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + phi::SelectedRows operator()(const phi::GPUContext& context, + const phi::SelectedRows& input, + const bool sorted_result) { + return MergeAddImpl()(context, input, sorted_result); + } + + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input, phi::SelectedRows* output, + const bool sorted_result) { + MergeAddImpl()(context, input, output, sorted_result); + } + + void operator()(const phi::GPUContext& context, + const std::vector& inputs, + phi::SelectedRows* output, const bool sorted_result) { + MergeAddImpl()(context, inputs, output, sorted_result); + } +}; + +#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype) \ + template struct MergeAddImpl; \ + template struct MergeAddImpl; \ + template struct MergeAdd; \ + template struct MergeAdd; + +TEMPLATE_SPECIALIZED_FOR_MERGEADD(float) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(double) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(int) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex) +TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex) template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index d443b7bb2a0922..761d65ed36c295 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) diff --git a/paddle/phi/kernels/clip_grad_kernel.h b/paddle/phi/kernels/clip_grad_kernel.h new file mode 100644 index 00000000000000..8a7e5b99fd9248 --- /dev/null +++ b/paddle/phi/kernels/clip_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void ClipGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& min, + const Scalar& max, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/clip_kernel.h b/paddle/phi/kernels/clip_kernel.h new file mode 100644 index 00000000000000..c64566e41ef50b --- /dev/null +++ b/paddle/phi/kernels/clip_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void ClipDenseKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& min, + const Scalar& max, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc new file mode 100644 index 00000000000000..5bf8226a303121 --- /dev/null +++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/clip_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_grad, + CPU, + ALL_LAYOUT, + phi::ClipGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc new file mode 100644 index 00000000000000..21385be3da2fca --- /dev/null +++ b/paddle/phi/kernels/cpu/clip_kernel.cc @@ -0,0 +1,18 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL( + clip, CPU, ALL_LAYOUT, phi::ClipDenseKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu new file mode 100644 index 00000000000000..469abb6573097d --- /dev/null +++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_grad_kernel.h" +#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_grad, + GPU, + ALL_LAYOUT, + phi::ClipGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu new file mode 100644 index 00000000000000..38ac4d62805447 --- /dev/null +++ b/paddle/phi/kernels/gpu/clip_kernel.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip, + GPU, + ALL_LAYOUT, + phi::ClipDenseKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h new file mode 100644 index 00000000000000..ac1254622997de --- /dev/null +++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/transform.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#endif + +namespace phi { + +template +class ClipGradFunctor { + public: + explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T x, const T y) const { + return (y > min_ && y < max_) ? x : static_cast(0); + } + + private: + T min_; + T max_; +}; + +template +void ClipGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& min, + const Scalar& max, + DenseTensor* x_grad) { + auto max_ = max.to(); + auto min_ = min.to(); + +#if defined(__NVCC__) || defined(__HIPCC__) + std::vector ins = {&out_grad, &x}; + std::vector outs = {x_grad}; + auto functor = ClipGradFunctor(min_, max_); + dev_ctx.template Alloc(x_grad); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); +#else + int64_t numel = out_grad.numel(); + auto* d_x_data = dev_ctx.template Alloc(x_grad); + const T* d_out_data = out_grad.data(); + const T* x_data = x.data(); + paddle::platform::Transform trans; + trans(dev_ctx, + d_out_data, + d_out_data + numel, + x_data, + d_x_data, + ClipGradFunctor(min_, max_)); +#endif +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h new file mode 100644 index 00000000000000..94066600d993c8 --- /dev/null +++ b/paddle/phi/kernels/impl/clip_kernel_impl.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/transform.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#endif + +namespace phi { + +template +class ClipFunctor { + public: + explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T x) const { + return x < min_ ? min_ : x > max_ ? max_ : x; + } + + private: + T min_; + T max_; +}; + +template +void ClipDenseKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& min, + const Scalar& max, + DenseTensor* out) { + auto max_ = max.to(); + auto min_ = min.to(); + + PADDLE_ENFORCE_LE( + min_, + max_, + errors::InvalidArgument("max should be greater than or equal to min. " + "But received min = %f, max = %f", + static_cast(min_), + static_cast(max_))); + + T* out_data = dev_ctx.template Alloc(out); + // const T* x_data = x->data(); + // int64_t numel = x->numel(); + const T* x_data = x.data(); + int64_t numel = x.numel(); + if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) { +#if defined(__NVCC__) || defined(__HIPCC__) + std::vector ins = {&x}; + std::vector outs = {out}; + auto functor = ClipFunctor(min_, max_); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); +#endif + } else { + paddle::platform::Transform trans; + trans( + dev_ctx, x_data, x_data + numel, out_data, ClipFunctor(min_, max_)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cc b/paddle/phi/kernels/selected_rows/clip_kernel.cc new file mode 100644 index 00000000000000..e2ccd790cdbb30 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/clip_kernel.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_kernel.h" + +PD_REGISTER_KERNEL(clip_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cu b/paddle/phi/kernels/selected_rows/clip_kernel.cu new file mode 100644 index 00000000000000..62bcff54f2b048 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/clip_kernel.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +PD_REGISTER_KERNEL(clip_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.h b/paddle/phi/kernels/selected_rows/clip_kernel.h new file mode 100644 index 00000000000000..a33564ce455116 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/clip_kernel.h @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +namespace phi { +namespace sr { + +template +void ClipSparseKernel(const Context& dev_ctx, + const SelectedRows& x, + const Scalar& min, + const Scalar& max, + SelectedRows* out) { + auto max_ = max.to(); + auto min_ = min.to(); + + PADDLE_ENFORCE_LE( + min_, + max_, + errors::InvalidArgument("max should be greater than or equal to min. " + "But received min = %f, max = %f", + static_cast(min_), + static_cast(max_))); + + PADDLE_ENFORCE_NE(&x, + out, + errors::InvalidArgument("Inplace clip is not allowed " + "when x is SelectedRows")); + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, x, out); + auto* out_tensor = out->mutable_value(); + auto* out_data = out_tensor->data(); + int64_t numel = out_tensor->numel(); + paddle::platform::Transform trans; + trans(dev_ctx, + out_data, + out_data + numel, + out_data, + ClipFunctor(min_, max_)); +} +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc new file mode 100644 index 00000000000000..5fee3e5170f1b1 --- /dev/null +++ b/paddle/phi/ops/compat/clip_sig.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { + paddle::SmallVector attr_names; + attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); + attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); + if (ctx.IsDenseTensorInput("X")) { + return KernelSignature("clip", {"X"}, std::move(attr_names), {"Out"}); + } else if (ctx.IsSelectedRowsInput("X")) { + return KernelSignature( + "clip_dense_param_sparse_grad", {"X"}, std::move(attr_names), {"Out"}); + } + + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + paddle::SmallVector attr_names; + attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); + attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + std::move(attr_names), + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(clip, phi::ClipOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(clip_grad, phi::ClipGradOpArgumentMapping); From 56f278be59f90e57afcd5ebe3ede9e3365ef8d83 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Thu, 17 Mar 2022 08:49:26 +0000 Subject: [PATCH 2/6] fix as review --- paddle/fluid/operators/clip_op.cc | 26 ++++---- paddle/fluid/operators/clip_op_npu.cc | 26 ++++---- paddle/fluid/operators/clip_op_xpu.cc | 26 ++++---- paddle/phi/kernels/clip_kernel.h | 10 +-- paddle/phi/kernels/cpu/clip_grad_kernel.cc | 23 ++++--- paddle/phi/kernels/cpu/clip_kernel.cc | 25 ++++---- paddle/phi/kernels/gpu/clip_grad_kernel.cu | 26 ++++---- paddle/phi/kernels/gpu/clip_kernel.cu | 25 ++++---- .../phi/kernels/impl/clip_grad_kernel_impl.h | 5 +- paddle/phi/kernels/impl/clip_kernel_impl.h | 15 ++--- .../phi/kernels/selected_rows/clip_kernel.cc | 64 +++++++++++++++---- .../phi/kernels/selected_rows/clip_kernel.cu | 28 ++++---- .../phi/kernels/selected_rows/clip_kernel.h | 29 +-------- paddle/phi/ops/compat/clip_sig.cc | 61 +++++++++++++++--- 14 files changed, 230 insertions(+), 159 deletions(-) diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index 3f8755a6203b51..6e898d31663fac 100644 --- a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include "paddle/fluid/framework/infershape_utils.h" diff --git a/paddle/fluid/operators/clip_op_npu.cc b/paddle/fluid/operators/clip_op_npu.cc index 32b17f94253a71..17d7ad97965040 100644 --- a/paddle/fluid/operators/clip_op_npu.cc +++ b/paddle/fluid/operators/clip_op_npu.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc index b11a6717c63bb6..c551312837274f 100644 --- a/paddle/fluid/operators/clip_op_xpu.cc +++ b/paddle/fluid/operators/clip_op_xpu.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #ifdef PADDLE_WITH_XPU diff --git a/paddle/phi/kernels/clip_kernel.h b/paddle/phi/kernels/clip_kernel.h index c64566e41ef50b..14ac8342e03bcf 100644 --- a/paddle/phi/kernels/clip_kernel.h +++ b/paddle/phi/kernels/clip_kernel.h @@ -22,10 +22,10 @@ namespace phi { template -void ClipDenseKernel(const Context& dev_ctx, - const DenseTensor& x, - const Scalar& min, - const Scalar& max, - DenseTensor* out); +void ClipKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& min, + const Scalar& max, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc index 5bf8226a303121..bccdc0746d51ca 100644 --- a/paddle/phi/kernels/cpu/clip_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc @@ -1,13 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/phi/kernels/clip_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc index 21385be3da2fca..5fd9aea966f8d2 100644 --- a/paddle/phi/kernels/cpu/clip_kernel.cc +++ b/paddle/phi/kernels/cpu/clip_kernel.cc @@ -1,13 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/phi/kernels/clip_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" @@ -15,4 +18,4 @@ limitations under the License. */ #include "paddle/phi/kernels/impl/clip_kernel_impl.h" PD_REGISTER_KERNEL( - clip, CPU, ALL_LAYOUT, phi::ClipDenseKernel, float, double, int, int64_t) {} + clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu index 469abb6573097d..b76086be648877 100644 --- a/paddle/phi/kernels/gpu/clip_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu @@ -1,17 +1,21 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_grad_kernel.h" #include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h" PD_REGISTER_KERNEL(clip_grad, diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu index 38ac4d62805447..9295b8b37a01ff 100644 --- a/paddle/phi/kernels/gpu/clip_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_kernel.cu @@ -1,13 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
@@ -18,7 +21,7 @@ limitations under the License. */
 PD_REGISTER_KERNEL(clip,
                    GPU,
                    ALL_LAYOUT,
-                   phi::ClipDenseKernel,
+                   phi::ClipKernel,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h
index ac1254622997de..2235156e37bb13 100644
--- a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h
@@ -21,7 +21,7 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/transform.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #endif
 
 namespace phi {
@@ -54,8 +54,7 @@ void ClipGradKernel(const Context& dev_ctx,
   std::vector<DenseTensor*> outs = {x_grad};
   auto functor = ClipGradFunctor<T>(min_, max_);
   dev_ctx.template Alloc<T>(x_grad);
-  paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-      dev_ctx, ins, &outs, functor);
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
 #else
   int64_t numel = out_grad.numel();
   auto* d_x_data = dev_ctx.template Alloc<T>(x_grad);
diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h
index 94066600d993c8..e95052d41d9d6a 100644
--- a/paddle/phi/kernels/impl/clip_kernel_impl.h
+++ b/paddle/phi/kernels/impl/clip_kernel_impl.h
@@ -21,7 +21,7 @@
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #endif
 
 namespace phi {
@@ -40,11 +40,11 @@
 };
 
 template <typename T, typename Context>
-void ClipDenseKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const Scalar& min,
-                     const Scalar& max,
-                     DenseTensor* out) {
+void ClipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& min,
+                const Scalar& max,
+                DenseTensor* out) {
   auto max_ = max.to<T>();
   auto min_ = min.to<T>();
 
@@ -66,8 +66,7 @@ void ClipDenseKernel(const Context& dev_ctx,
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
   auto functor = ClipFunctor<T>(min_, max_);
-  paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-      dev_ctx, ins, &outs, functor);
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
 #endif
   } else {
     paddle::platform::Transform<Context> trans;
diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cc b/paddle/phi/kernels/selected_rows/clip_kernel.cc
index e2ccd790cdbb30..905db70b4319d3 100644
--- a/paddle/phi/kernels/selected_rows/clip_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/clip_kernel.cc
@@ -1,20 +1,62 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/kernels/selected_rows/clip_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/clip_kernel.h"
 
-PD_REGISTER_KERNEL(clip_dense_param_sparse_grad,
+namespace phi {
+namespace sr {
+
+template <typename T, typename Context>
+void ClipSparseKernel(const Context& dev_ctx,
+                      const SelectedRows& x,
+                      const Scalar& min,
+                      const Scalar& max,
+                      SelectedRows* out) {
+  auto max_ = max.to<T>();
+  auto min_ = min.to<T>();
+
+  PADDLE_ENFORCE_LE(
+      min_,
+      max_,
+      errors::InvalidArgument("max should be greater than or equal to min. "
+                              "But received min = %f, max = %f",
+                              static_cast<float>(min_),
+                              static_cast<float>(max_)));
+
+  PADDLE_ENFORCE_NE(&x,
+                    out,
+                    errors::InvalidArgument("Inplace clip is not allowed "
+                                            "when x is SelectedRows"));
+  paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
+  merge_func(dev_ctx, x, out);
+  auto* out_tensor = out->mutable_value();
+  auto* out_data = out_tensor->data<T>();
+  int64_t numel = out_tensor->numel();
+  paddle::platform::Transform<Context> trans;
+  trans(dev_ctx,
+        out_data,
+        out_data + numel,
+        out_data,
+        ClipFunctor<T>(min_, max_));
+}
+} // namespace sr
+} // namespace phi
+
+PD_REGISTER_KERNEL(clip_sr,
                    CPU,
                    ALL_LAYOUT,
                    phi::sr::ClipSparseKernel,
diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cu b/paddle/phi/kernels/selected_rows/clip_kernel.cu
index 62bcff54f2b048..eb8f6dba5c17c9 100644
--- a/paddle/phi/kernels/selected_rows/clip_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/clip_kernel.cu
@@ -1,21 +1,25 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/selected_rows/clip_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/clip_kernel.h"
-#include "paddle/phi/kernels/selected_rows/clip_kernel.h"
 
-PD_REGISTER_KERNEL(clip_dense_param_sparse_grad,
+PD_REGISTER_KERNEL(clip_sr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sr::ClipSparseKernel,
diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.h b/paddle/phi/kernels/selected_rows/clip_kernel.h
index a33564ce455116..ec56d92c513ea2 100644
--- a/paddle/phi/kernels/selected_rows/clip_kernel.h
+++ b/paddle/phi/kernels/selected_rows/clip_kernel.h
@@ -29,33 +29,6 @@ void ClipSparseKernel(const Context& dev_ctx,
                       const SelectedRows& x,
                       const Scalar& min,
                       const Scalar& max,
-                      SelectedRows* out) {
-  auto max_ = max.to<T>();
-  auto min_ = min.to<T>();
-
-  PADDLE_ENFORCE_LE(
-      min_,
-      max_,
-      errors::InvalidArgument("max should be greater than or equal to min. "
-                              "But received min = %f, max = %f",
-                              static_cast<float>(min_),
-                              static_cast<float>(max_)));
-
-  PADDLE_ENFORCE_NE(&x,
-                    out,
-                    errors::InvalidArgument("Inplace clip is not allowed "
-                                            "when x is SelectedRows"));
-  paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
-  merge_func(dev_ctx, x, out);
-  auto* out_tensor = out->mutable_value();
-  auto* out_data = out_tensor->data<T>();
-  int64_t numel = out_tensor->numel();
-  paddle::platform::Transform<Context> trans;
-  trans(dev_ctx,
-        out_data,
-        out_data + numel,
-        out_data,
-        ClipFunctor<T>(min_, max_));
-}
+                      SelectedRows* out);
 } // namespace sr
 } // namespace phi
diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc
index 5fee3e5170f1b1..78fa6c36a51492 100644
--- a/paddle/phi/ops/compat/clip_sig.cc
+++ b/paddle/phi/ops/compat/clip_sig.cc
@@ -22,23 +22,64 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
   attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
   attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
   if (ctx.IsDenseTensorInput("X")) {
-    return KernelSignature("clip", {"X"}, std::move(attr_names), {"Out"});
+    if (ctx.HasInput("Min")) {
+      if (ctx.HasInput("Max")) {
+        return KernelSignature("clip", {"X"}, {"Min", "Max"}, {"Out"});
+      } else {
+        return KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"});
+      }
+    } else {
+      if (ctx.HasInput("Max")) {
+        return KernelSignature("clip", {"X"}, {"min", "Max"}, {"Out"});
+      } else {
+        return KernelSignature("clip", {"X"}, {"min", "max"}, {"Out"});
+      }
+    }
   } else if (ctx.IsSelectedRowsInput("X")) {
-    return KernelSignature(
-        "clip_dense_param_sparse_grad", {"X"}, std::move(attr_names), {"Out"});
+    if (ctx.HasInput("Min")) {
+      if (ctx.HasInput("Max")) {
+        return KernelSignature("clip_sr", {"X"}, {"Min", "Max"}, {"Out"});
+      } else {
+        return KernelSignature("clip_sr", {"X"}, {"Min", "max"}, {"Out"});
+      }
+    } else {
+      if (ctx.HasInput("Max")) {
+        return KernelSignature("clip_sr", {"X"}, {"min", "Max"}, {"Out"});
+      } else {
+        return KernelSignature("clip_sr", {"X"}, {"min", "max"}, {"Out"});
+      }
+    }
   }
 
   return KernelSignature("unregistered", {}, {}, {});
 }
 
 KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  paddle::SmallVector<std::string> attr_names;
-  attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
-  attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
-  return KernelSignature("clip_grad",
-                         {"X", GradVarName("Out")},
-                         std::move(attr_names),
-                         {GradVarName("X")});
+  if (ctx.HasInput("Min")) {
+    if (ctx.HasInput("Max")) {
+      return KernelSignature("clip_grad",
+                             {"X", GradVarName("Out")},
+                             {"Min", "Max"},
+                             {GradVarName("X")});
+    } else {
+      return KernelSignature("clip_grad",
+                             {"X", GradVarName("Out")},
+                             {"Min", "max"},
+                             {GradVarName("X")});
+    }
+  } else {
+    if (ctx.HasInput("Max")) {
+      return KernelSignature("clip_grad",
+                             {"X", GradVarName("Out")},
+                             {"min", "Max"},
+                             {GradVarName("X")});
+    } else {
+      return KernelSignature("clip_grad",
+                             {"X", GradVarName("Out")},
+                             {"min", "max"},
+                             {GradVarName("X")});
+    }
+  }
 }
 
 } // namespace phi

From 3e2c7bc724a89b3a9d518f01039a5ed985b19341 Mon Sep 17 00:00:00 2001
From: chenguowei01
Date: Tue, 22 Mar 2022 07:46:35 +0000
Subject: [PATCH 3/6] update hierarchical_sigmoid_kernel.cc

---
 paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
index 096a54f9fb263d..4c4f1aa125a339 100644
--- a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
+++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc
@@ -14,7 +14,6 @@
 
 #include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h"
 
-#include "paddle/fluid/operators/clip_op.h"
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
@@ -22,6 +21,7 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
 
 namespace phi {
 
@@ -92,8 +92,7 @@ void HierarchicalSigmoidKernel(const Context& ctx,
         pre_out_data,
         pre_out_data + pre_out->numel(),
         pre_out_data,
-        paddle::operators::ClipFunctor<T>(static_cast<T>(-40.0),
-                                          static_cast<T>(40.0)));
+        ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
   bit_code->Sum(*pre_out, out, static_cast<T>(-1));
   // use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); From de681a963eb4ce3681cd6d929d76876cbd9576df Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Sun, 27 Mar 2022 15:33:41 +0000 Subject: [PATCH 4/6] update selected_rows --- .../operators/math/selected_rows_functor.cc | 45 -------- .../operators/math/selected_rows_functor.cu | 107 ------------------ paddle/phi/kernels/CMakeLists.txt | 2 +- 3 files changed, 1 insertion(+), 153 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 0ca2529f132a0b..977b8dd21d7da7 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -279,46 +279,6 @@ struct SelectedRowsAddToTensor { } }; -template -struct SelectedRowsAddToTensor { - void operator()(const phi::CPUContext& context, - const phi::SelectedRows& input1, framework::Tensor* input2) { - if (UNLIKELY(input1.rows().size() == 0)) { - LOG(WARNING) << "input selected rows is empty!"; - return; - } - auto in1_height = input1.height(); - auto in2_dims = input2->dims(); - PADDLE_ENFORCE_EQ( - in1_height, in2_dims[0], - platform::errors::InvalidArgument("The two inputs height must be equal." - "But recieved first input height = " - "[%d], second input height = [%d]", - in1_height, in2_dims[0])); - - auto& in1_value = input1.value(); - auto& in1_rows = input1.rows(); - - int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ( - in1_row_numel, input2->numel() / in1_height, - platform::errors::InvalidArgument( - "The two inputs width must be equal." - "But recieved first input width = [%d], second input width = [%d]", - in1_row_numel, input2->numel() / in1_height)); - - auto* in1_data = in1_value.data(); - auto* input2_data = input2->data(); - - for (size_t i = 0; i < in1_rows.size(); i++) { - for (int64_t j = 0; j < in1_row_numel; j++) { - input2_data[in1_rows[i] * in1_row_numel + j] += - in1_data[i * in1_row_numel + j]; - } - } - } -}; - template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; @@ -326,11 +286,6 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; // This is a separated namespace for manipulate SelectedRows typed // data. Like merge duplicated rows, adding two SelectedRows etc. // diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 542d4c9784352e..16ef013f689c4f 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -174,77 +174,12 @@ struct SelectedRowsAddTensor { } }; -template -struct SelectedRowsAddTensor { - void operator()(const phi::GPUContext& context, - const phi::SelectedRows& input1, - const framework::Tensor& input2, framework::Tensor* output) { - auto in1_height = input1.height(); - auto in2_dims = input2.dims(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_EQ( - in1_height, in2_dims[0], - platform::errors::InvalidArgument( - "The two inputs height must be equal." 
- "But recieved first input height = [%d], first input height = [%d]", - in1_height, in2_dims[0])); - PADDLE_ENFORCE_EQ( - in1_height, out_dims[0], - platform::errors::InvalidArgument( - "The input and output height must be equal." - "But recieved input height = [%d], output height = [%d]", - in1_height, out_dims[0])); - - auto& in1_value = input1.value(); - auto& in1_rows = input1.rows(); - - int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ( - in1_row_numel, input2.numel() / in1_height, - platform::errors::InvalidArgument( - "The two inputs width must be equal." - "But recieved first input width = [%d], second input width = [%d]", - in1_row_numel, input2.numel() / in1_height)); - PADDLE_ENFORCE_EQ( - in1_row_numel, output->numel() / in1_height, - platform::errors::InvalidArgument( - "The input and output width must be equal." - "But recieved input width = [%d], output width = [%d]", - in1_row_numel, output->numel() / in1_height)); - - auto* in1_data = in1_value.data(); - auto* in2_data = input2.data(); - auto* out_data = output->data(); - - phi::funcs::SetConstant functor; - functor(context, output, static_cast(0)); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid(in1_rows.size(), 1); - paddle::framework::MixVector mixv_in1_rows(&in1_rows); - SelectedRowsAddTensorKernel< - T, block_size><<>>( - in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data, - in1_row_numel); - - auto out_eigen = framework::EigenVector::Flatten(*output); - auto in2_eigen = framework::EigenVector::Flatten(input2); - out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen; - } -}; - template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; template struct SelectedRowsAdd; template struct SelectedRowsAddTensor; -template struct SelectedRowsAddTensor; -template struct SelectedRowsAddTensor; -template struct SelectedRowsAdd; -template struct SelectedRowsAddTensor; - template struct SelectedRowsAddTo { void operator()(const platform::CUDADeviceContext& context, @@ -350,54 +285,12 @@ struct SelectedRowsAddToTensor { } }; -template -struct SelectedRowsAddToTensor { - void operator()(const phi::GPUContext& context, - const phi::SelectedRows& input1, framework::Tensor* input2) { - auto in1_height = input1.height(); - auto in2_dims = input2->dims(); - PADDLE_ENFORCE_EQ( - in1_height, in2_dims[0], - platform::errors::InvalidArgument("The two inputs height must be equal." - "But recieved first input height = " - "[%d], second input height = [%d]", - in1_height, in2_dims[0])); - - auto& in1_value = input1.value(); - auto& in1_rows = input1.rows(); - - int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ( - in1_row_numel, input2->numel() / in1_height, - platform::errors::InvalidArgument( - "The two inputs width must be equal." 
- "But recieved first input width = [%d], second input width = [%d]", - in1_row_numel, input2->numel() / in1_height)); - - auto* in1_data = in1_value.data(); - auto* in2_data = input2->data(); - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid(in1_rows.size(), 1); - paddle::framework::MixVector mixv_in1_rows(&in1_rows); - SelectedRowsAddToTensorKernel< - T, block_size><<>>( - in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data, - in1_row_numel); - } -}; - template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; -template struct SelectedRowsAddToTensor; namespace scatter { diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index ac1cc1ccf961fd..c752387ee35ba4 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) From 193122ca20d8b3138977a78685def54dc1226b69 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Mon, 28 Mar 2022 08:57:48 +0000 Subject: [PATCH 5/6] update clip_kernel.cu --- paddle/phi/kernels/gpu/clip_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu index 9295b8b37a01ff..9e0050db7fdbf1 100644 --- a/paddle/phi/kernels/gpu/clip_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_kernel.cu @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/phi/kernels/clip_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_kernel.h" #include "paddle/phi/kernels/impl/clip_kernel_impl.h" PD_REGISTER_KERNEL(clip, From fff4cf59043766eeb6e5a615e879192bf84ca441 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 30 Mar 2022 03:38:49 +0000 Subject: [PATCH 6/6] fix as review --- .../phi/kernels/impl/clip_grad_kernel_impl.h | 3 +- paddle/phi/kernels/impl/clip_kernel_impl.h | 3 +- .../kernels/selected_rows/cpu/clip_kernel.cc | 28 +++++++++++++++++++ .../selected_rows/{ => gpu}/clip_kernel.cu | 2 +- .../clip_kernel_impl.h} | 20 ++++++------- 5 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc rename paddle/phi/kernels/selected_rows/{ => gpu}/clip_kernel.cu (94%) rename paddle/phi/kernels/selected_rows/{clip_kernel.cc => impl/clip_kernel_impl.h} (83%) diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h index 2235156e37bb13..7ce86492327bac 100644 --- a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h @@ -14,9 +14,10 @@ #pragma once +#include "paddle/phi/kernels/clip_kernel.h" + #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_kernel.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/transform.h" diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h index e95052d41d9d6a..17c04c31a598af 100644 --- a/paddle/phi/kernels/impl/clip_kernel_impl.h +++ b/paddle/phi/kernels/impl/clip_kernel_impl.h @@ -14,9 +14,10 @@ #pragma once +#include "paddle/phi/kernels/clip_kernel.h" + #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_kernel.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" diff --git a/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc new file mode 100644 index 00000000000000..0098bf13f2b2f1 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_sr, + CPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu similarity index 94% rename from paddle/phi/kernels/selected_rows/clip_kernel.cu rename to paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu index eb8f6dba5c17c9..a8d659559e19e5 100644 --- a/paddle/phi/kernels/selected_rows/clip_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" PD_REGISTER_KERNEL(clip_sr, GPU, diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.cc b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h similarity index 83% rename from paddle/phi/kernels/selected_rows/clip_kernel.cc rename to paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h index 905db70b4319d3..1d95e633b93a6e 100644 --- a/paddle/phi/kernels/selected_rows/clip_kernel.cc +++ b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h @@ -12,10 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once + #include "paddle/phi/kernels/selected_rows/clip_kernel.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" namespace phi { namespace sr { @@ -55,12 +60,3 @@ void ClipSparseKernel(const Context& dev_ctx, } } // namespace sr } // namespace phi - -PD_REGISTER_KERNEL(clip_sr, - CPU, - ALL_LAYOUT, - phi::sr::ClipSparseKernel, - float, - double, - int, - int64_t) {}