From a42fe9ce2c31f4f673c6b0d952d0f55aff79d43f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 4 Mar 2022 09:49:15 +0000 Subject: [PATCH 1/3] save code --- paddle/fluid/framework/infershape_utils.cc | 54 ++ paddle/fluid/framework/operator.cc | 56 ++ paddle/fluid/imperative/prepared_operator.h | 54 ++ paddle/phi/core/kernel_utils.h | 1 + paddle/phi/kernels/cpu/set_value_kernel.cc | 38 ++ paddle/phi/kernels/gpu/set_value_kernel.cu | 38 ++ .../phi/kernels/impl/set_value_kernel_impl.h | 330 ++++++++++ paddle/phi/kernels/set_value_kernel.h | 49 ++ paddle/phi/ops/compat/set_value_sig.cc | 616 ++++++++++++++++++ 9 files changed, 1236 insertions(+) create mode 100644 paddle/phi/kernels/cpu/set_value_kernel.cc create mode 100644 paddle/phi/kernels/gpu/set_value_kernel.cu create mode 100644 paddle/phi/kernels/impl/set_value_kernel_impl.h create mode 100644 paddle/phi/kernels/set_value_kernel.h create mode 100644 paddle/phi/ops/compat/set_value_sig.cc diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 57fb68e80427af..c6c770849eaf04 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -433,6 +433,60 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name, infershape_input.size())); } } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = attr_reader.GetAttr(attr_name); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct InferMetaContext.", + attr_name)); + } } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index b12ad552aba6e6..0ae121425fb911 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1186,7 +1186,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // phase phi::KernelKey pt_kernel_key; std::string pt_kernel_name; + VLOG(1) << "########### " << phi::KernelFactory::Instance(); if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { + VLOG(1) << "######### HasCompatiblePhiKernel"; if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { pt_kernel_signature_.reset( new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx)))); @@ -2178,6 +2180,60 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = Attrs().at(attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); + } } else { // TODO(chenweihang): support other attrs later auto& attr = Attrs().at(attr_names[i]); diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 3b5762720e7fb4..88e2ecfea5b731 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -406,6 +406,60 @@ void BuildDygraphPhiKernelContext( experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); + } } else { // TODO(chenweihang): support other attrs later auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index e5de5e2b49ebb2..71b9afa6b0b2c6 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -250,6 +250,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); /* Output Helpers */ diff --git a/paddle/phi/kernels/cpu/set_value_kernel.cc b/paddle/phi/kernels/cpu/set_value_kernel.cc new file mode 100644 index 00000000000000..dcf278cd94e651 --- /dev/null +++ b/paddle/phi/kernels/cpu/set_value_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/set_value_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" + +PD_REGISTER_KERNEL(set_value, + CPU, + ALL_LAYOUT, + phi::SetValueKernel, + float, + double, + int, + int64_t, + bool) {} +PD_REGISTER_KERNEL(set_value_with_tensor, + CPU, + ALL_LAYOUT, + phi::SetTensorValueKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/gpu/set_value_kernel.cu b/paddle/phi/kernels/gpu/set_value_kernel.cu new file mode 100644 index 00000000000000..f788da010b6827 --- /dev/null +++ b/paddle/phi/kernels/gpu/set_value_kernel.cu @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/set_value_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" + +PD_REGISTER_KERNEL(set_value, + GPU, + ALL_LAYOUT, + phi::SetValueKernel, + float, + double, + int, + int64_t, + bool) {} +PD_REGISTER_KERNEL(set_value_with_tensor, + GPU, + ALL_LAYOUT, + phi::SetTensorValueKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h new file mode 100644 index 00000000000000..5d98cf5a8e8516 --- /dev/null +++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h @@ -0,0 +1,330 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/slice_utils.h" + +namespace phi { + +// check whether the tensor with dimension of second can assign to the +// tensor with dimension of first +inline void CheckIsDimsMatch(const DDim first, const DDim second) { + int ignore_axis1 = 0, ignore_axis2 = 0; + for (; ignore_axis1 < first.size(); ++ignore_axis1) { + if (first[ignore_axis1] != 1) { + break; + } + } + for (; ignore_axis2 < second.size(); ++ignore_axis2) { + if (second[ignore_axis2] != 1) { + break; + } + } + + if (second.size() == ignore_axis2) { + // second tensor has only one value + return; + } + + if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) { + auto idx1 = first.size() - 1; + auto idx2 = second.size() - 1; + bool is_match = true; + for (; idx2 >= ignore_axis2; idx2--) { + if (first[idx1--] != second[idx2] && second[idx2] != 1) { + is_match = false; + break; + } + } + if (is_match) { + return; + } + } + PADDLE_THROW(errors::InvalidArgument( + "The shape of tensor assigned value must match the shape " + "of target shape: %d, but now shape is %d.", + second.to_str(), + first.to_str())); +} + +template +void SetValueImpl(const Context& dev_ctx, + const DenseTensor& in, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out) { + auto in_dims = in.dims(); + std::vector starts_local = starts.GetData(); + std::vector ends_local = ends.GetData(); + std::vector steps_local = steps.GetData(); + paddle::operators::CheckAndUpdateSliceAttrs( + in_dims, axes, &starts_local, &ends_local, &steps_local); + auto slice_dims = + GetSliceDims(in_dims, axes, starts_local, ends_local, &steps_local); + auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); + + auto slice_dims_for_assign = decrease_slice_dims; + if (!none_axes.empty()) { + std::vector slice_dims_with_none; + + size_t none_axes_cur = 0, decrease_axes_cur = 0; + for (int i = 0; i < slice_dims.size(); ++i) { + while (none_axes_cur < none_axes.size() && + none_axes[none_axes_cur] <= i) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + if (decrease_axes_cur < decrease_axes.size() && + decrease_axes[decrease_axes_cur] == i) { + decrease_axes_cur++; + } else { + slice_dims_with_none.push_back(slice_dims[i]); + } + } + while (none_axes_cur < none_axes.size()) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + + slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); + } + + auto place = dev_ctx.GetPlace(); + auto& eigen_place = *dev_ctx.eigen_device(); + + // Here copy data from input to avoid data loss at PE and Graph level. + // TODO(liym27): Speed up in the future version. + // - Q: Why don't call ShareDataWith to speed up? + // - A: Because it's not supported to ShareDataWith on OP's input and output + // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP + // - Q: Why don't delete Input, after all, the input and output are the same + // Tensor at program level? + // - A: If deleting Input, the graph will be complex, such as there will + // be two ops points to the output in graph: op1 -> output <- set_value. + // In this case, we have to find a way to handle the running order of + // set_value is what we want. + Copy(dev_ctx, in, place, false, out); + + Tensor slice_tensor = + Empty(dev_ctx, ScalarArray{slice_dims.Get(), slice_dims.size()}); + Tensor pad_tensor = + Empty(dev_ctx, ScalarArray{in_dims.Get(), in_dims.size()}); + + auto pad_e = framework::EigenTensor::From(pad_tensor, in_dims); + auto out_e = framework::EigenTensor::From(*out); + auto slice_e = + framework::EigenTensor::From(slice_tensor, slice_dims); + + // Step 1: Set the value of out at `_index` to zero + slice_e.device(eigen_place) = slice_e.constant(T(0)); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + for (size_t i = 0; i < D; ++i) { + starts_indices[i] = 0; + ends_indices[i] = slice_dims[i]; + strides_indices[i] = 1; + } + for (size_t i = 0; i < axes.size(); i++) { + int axis_index = axes[i]; + starts_indices[axis_index] = starts_local[i]; + ends_indices[axis_index] = ends_local[i]; + strides_indices[axis_index] = steps_local[i]; + if (starts_local[i] == + ends_local[i]) { // slice is empty, data will not be changed + return; + } + } + + out_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 2: Set a tensor with the same shape as out tensor. And its data at + // '_index' is the same as value, and data out of '_index' to zero + + // - Step 2.1 Set slice tensor with value + + // NOTE(liym27): [ Why resize slice_tensor here? ] + // A: When do broadcasting on slice_tensor and value, the shape of + // slice_tensor should be decreased dims. + // e.g. + // x[:,0] = value + // x's shape = [3, 4], value's shape = [3] + // We get slice_dims = [3, 1], decrease_slice_dims = [3] + // If do broadcasting on Tensor with shape [3, 1] and [3], the result's + // shape is [3, 3], which cross the border; + // If do broadcasting on Tensor with shape [3] and [3], the result's shape + // is [3], which is right. + + slice_tensor.Resize(slice_dims_for_assign); + CheckIsDimsMatch(slice_dims_for_assign, value.dims()); + // ElementwiseComputeEx can do broadcasting + paddle::operators::ElementwiseComputeEx, Context, T>( + ctx, &slice_tensor, value, -1, SubFunctor(), &slice_tensor); + + slice_tensor.Resize(slice_dims); + + // - Step 2.2 Pad slice tensor with 0 + pad_e.device(eigen_place) = pad_e.constant(T(0)); + pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 3: Set out tensor with value + out_e.device(eigen_place) = out_e - pad_e; +} + +template +void SetTensorValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out) { + const int rank = x.dims().size(); + + switch (rank) { + case 1: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 2: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 3: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 4: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 5: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 6: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + default: + PADDLE_THROW(errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void SetValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + const std::vector& shape, + const std::vector& values, + DenseTensor* out) { + const std::vector& assgin_values; + assgin_values.reserve(values.size()); + for (const auto& val : values) { + assgin_values.push_back(val.to()); + } + DensorTensor value_tensor = Empty(dev_ctx, shape); + paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor); + + SetTensorValueKernel(dev_ctx, + x, + value_tensor, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); +} + +} // namespace phi diff --git a/paddle/phi/kernels/set_value_kernel.h b/paddle/phi/kernels/set_value_kernel.h new file mode 100644 index 00000000000000..271691b1a3596f --- /dev/null +++ b/paddle/phi/kernels/set_value_kernel.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" + +namespace phi { + +template +void SetTensorValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out); + +template +void SetValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + const std::vector& shape, + const std::vector& values, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc new file mode 100644 index 00000000000000..104eb0dc6eafb8 --- /dev/null +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -0,0 +1,616 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("Input")) { + if (ctx.HasInput("StartsTensorList")) { + if (ctx.HasInput("EndsTensorList")) { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } else { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } + } else { + if (ctx.HasInput("EndsTensorList")) { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } else { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values")) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } + } + } + return KernelSignature("unregistered", {}, {}, {}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(set_value, phi::SetValueOpArgumentMapping); From d9977a8057206379c1eeb1f2953d056da484fed3 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 6 Mar 2022 05:37:26 +0000 Subject: [PATCH 2/3] fix bug of set_value --- paddle/fluid/framework/operator.cc | 22 +- paddle/fluid/framework/operator.h | 4 +- paddle/fluid/imperative/execution_context.h | 5 + paddle/fluid/imperative/prepared_operator.h | 7 +- .../phi/kernels/impl/set_value_kernel_impl.h | 69 +++--- paddle/phi/ops/compat/set_value_sig.cc | 200 ++++++++++++++---- 6 files changed, 231 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d1bde8ba70bd1a..2ad3b5056b5d97 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const { return var != nullptr; } +bool ExecutionContext::HasInputs(const std::string& name) const { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end() || it->second.empty()) { + return false; + } + for (const auto* input : it->second) { + if (input == nullptr) { + return false; + } + } + return true; +} + bool ExecutionContext::HasOutput(const std::string& name) const { auto* var = OutputVar(name); return var != nullptr; @@ -1186,9 +1200,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // phase phi::KernelKey pt_kernel_key; std::string pt_kernel_name; - VLOG(1) << "########### " << phi::KernelFactory::Instance(); if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { - VLOG(1) << "######### HasCompatiblePhiKernel"; if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { pt_kernel_signature_.reset( new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx)))); @@ -2268,7 +2280,11 @@ void OperatorWithKernel::BuildPhiKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + std::type_index(typeid(std::vector))) { + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e33d4feb82a9e7..1a1171f1dba4d7 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -295,6 +295,8 @@ class ExecutionContext { virtual bool HasInput(const std::string& name) const; + virtual bool HasInputs(const std::string& name) const; + virtual bool HasOutput(const std::string& name) const; virtual size_t InputSize(const std::string& name) const { @@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { : ctx_(ctx) {} bool HasInput(const std::string& name) const override { - return ctx_.HasInput(name); + return ctx_.HasInputs(name); } bool HasOutput(const std::string& name) const override { diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index fe5ac73b004691..fbc47f81fd3316 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext { return (it != var_map_in_.end() && it->second.size() > 0); } + bool HasInputs(const std::string& name) const override { + auto it = var_map_in_.find(name); + return (it != var_map_in_.end() && it->second.size() > 0); + } + bool HasOutput(const std::string& name) const override { auto it = var_map_out_.find(name); return (it != var_map_out_.end() && it->second.size() > 0); diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 90ad3afbe53fcc..d7c0c8cc547e6b 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext( } for (size_t i = 0; i < attr_names.size(); ++i) { + VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i]; if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) { if (attrs.find(attr_names[i]) != attrs.end()) { // shape is in the attribute @@ -486,7 +487,11 @@ void BuildDygraphPhiKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h index 5d98cf5a8e8516..5aebffe51b5e38 100644 --- a/paddle/phi/kernels/impl/set_value_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h @@ -20,18 +20,19 @@ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/slice_utils.h" namespace phi { // check whether the tensor with dimension of second can assign to the // tensor with dimension of first -inline void CheckIsDimsMatch(const DDim first, const DDim second) { +inline void CheckIsDimsMatch(const DDim& first, const DDim& second) { int ignore_axis1 = 0, ignore_axis2 = 0; for (; ignore_axis1 < first.size(); ++ignore_axis1) { if (first[ignore_axis1] != 1) { @@ -70,7 +71,7 @@ inline void CheckIsDimsMatch(const DDim first, const DDim second) { first.to_str())); } -template +template void SetValueImpl(const Context& dev_ctx, const DenseTensor& in, const DenseTensor& value, @@ -87,9 +88,10 @@ void SetValueImpl(const Context& dev_ctx, std::vector steps_local = steps.GetData(); paddle::operators::CheckAndUpdateSliceAttrs( in_dims, axes, &starts_local, &ends_local, &steps_local); - auto slice_dims = - GetSliceDims(in_dims, axes, starts_local, ends_local, &steps_local); - auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); + auto slice_dims = paddle::operators::GetSliceDims( + in_dims, axes, starts_local, ends_local, &steps_local); + auto decrease_slice_dims = + paddle::operators::GetDecreasedDims(slice_dims, decrease_axes); auto slice_dims_for_assign = decrease_slice_dims; if (!none_axes.empty()) { @@ -133,24 +135,23 @@ void SetValueImpl(const Context& dev_ctx, // set_value is what we want. Copy(dev_ctx, in, place, false, out); - Tensor slice_tensor = + DenseTensor slice_tensor = Empty(dev_ctx, ScalarArray{slice_dims.Get(), slice_dims.size()}); - Tensor pad_tensor = + DenseTensor pad_tensor = Empty(dev_ctx, ScalarArray{in_dims.Get(), in_dims.size()}); - auto pad_e = framework::EigenTensor::From(pad_tensor, in_dims); - auto out_e = framework::EigenTensor::From(*out); - auto slice_e = - framework::EigenTensor::From(slice_tensor, slice_dims); + auto pad_e = EigenTensor::From(pad_tensor, in_dims); + auto out_e = EigenTensor::From(*out); + auto slice_e = EigenTensor::From(slice_tensor, slice_dims); // Step 1: Set the value of out at `_index` to zero slice_e.device(eigen_place) = slice_e.constant(T(0)); - auto starts_indices = Eigen::DSizes(); - auto ends_indices = Eigen::DSizes(); - auto strides_indices = Eigen::DSizes(); + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); - for (size_t i = 0; i < D; ++i) { + for (size_t i = 0; i < RANK; ++i) { starts_indices[i] = 0; ends_indices[i] = slice_dims[i]; strides_indices[i] = 1; @@ -189,8 +190,13 @@ void SetValueImpl(const Context& dev_ctx, slice_tensor.Resize(slice_dims_for_assign); CheckIsDimsMatch(slice_dims_for_assign, value.dims()); // ElementwiseComputeEx can do broadcasting - paddle::operators::ElementwiseComputeEx, Context, T>( - ctx, &slice_tensor, value, -1, SubFunctor(), &slice_tensor); + funcs::ElementwiseCompute, T>( + dev_ctx, + slice_tensor, + value, + -1, + funcs::SubtractFunctor(), + &slice_tensor); slice_tensor.Resize(slice_dims); @@ -307,24 +313,25 @@ void SetValueKernel(const Context& dev_ctx, const std::vector& shape, const std::vector& values, DenseTensor* out) { - const std::vector& assgin_values; + std::vector assgin_values; assgin_values.reserve(values.size()); for (const auto& val : values) { assgin_values.push_back(val.to()); } - DensorTensor value_tensor = Empty(dev_ctx, shape); + DenseTensor value_tensor = Empty(dev_ctx, shape); paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor); - - SetTensorValueKernel(dev_ctx, - x, - value_tensor, - starts, - ends, - steps, - axes, - decrease_axes, - none_axes, - out); + value_tensor.Resize(phi::make_ddim(shape)); + + SetTensorValueKernel(dev_ctx, + x, + value_tensor, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); } } // namespace phi diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc index 104eb0dc6eafb8..eacfff26d53cf1 100644 --- a/paddle/phi/ops/compat/set_value_sig.cc +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -32,7 +32,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -44,7 +47,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -56,7 +62,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -68,7 +77,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -80,7 +92,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -104,7 +119,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -116,7 +134,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -128,7 +149,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -140,7 +164,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -152,7 +179,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -178,7 +208,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -190,7 +223,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -202,7 +238,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -214,7 +253,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -226,7 +268,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -250,7 +295,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -262,7 +310,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -274,7 +325,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -286,7 +340,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -298,7 +355,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"StartsTensorList", @@ -326,7 +386,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -338,7 +401,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -350,7 +416,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -362,7 +431,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -374,7 +446,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -398,7 +473,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -410,7 +488,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -422,7 +503,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -434,7 +518,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -446,7 +533,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -472,7 +562,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -484,7 +577,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -496,7 +592,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -508,7 +607,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -520,7 +622,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -544,7 +649,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "decrease_axes", "none_axes"}, {"Out"}); - } else if (ctx.HasAttr("fp32_values")) { + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -556,7 +664,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp32_values"}, {"Out"}); - } else if (ctx.HasAttr("fp64_values")) { + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -568,7 +679,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "fp64_values"}, {"Out"}); - } else if (ctx.HasAttr("int32_values")) { + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -580,7 +694,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int32_values"}, {"Out"}); - } else if (ctx.HasAttr("int64_values")) { + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", @@ -592,7 +709,10 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "int64_values"}, {"Out"}); - } else if (ctx.HasAttr("bool_values")) { + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { return KernelSignature("set_value", {"Input"}, {"starts", From d5983b2f586a329299c9cecb26b489c205d5244e Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 7 Mar 2022 05:30:17 +0000 Subject: [PATCH 3/3] add coverage test --- paddle/fluid/framework/infershape_utils.cc | 54 --- paddle/fluid/framework/operator.cc | 9 - paddle/phi/tests/ops/test_op_signature.cc | 370 +++++++++++++++++++++ 3 files changed, 370 insertions(+), 63 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 411290c74b2bef..7232a707916dd5 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -437,60 +437,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name, infershape_input.size())); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { - auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to vector when " - "construct InferMetaContext.", - attr_name)); - } } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2ad3b5056b5d97..f8e30c1ee294ec 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2242,15 +2242,6 @@ void OperatorWithKernel::BuildPhiKernelContext( scalar_list.emplace_back(val); } pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to vector when " diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc index a6c9a27de7dc5a..88c9193a8f8949 100644 --- a/paddle/phi/tests/ops/test_op_signature.cc +++ b/paddle/phi/tests/ops/test_op_signature.cc @@ -114,5 +114,375 @@ TEST(ARG_MAP, fill_constant) { ASSERT_EQ(signature9.name, "full_sr"); } +TEST(ARG_MAP, set_value) { + TestArgumentMappingContext arg_case( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name, + "set_value"); + + TestArgumentMappingContext arg_case1( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name, + "set_value"); + + TestArgumentMappingContext arg_case2( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name, + "set_value"); + + TestArgumentMappingContext arg_case3( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name, + "set_value"); + + TestArgumentMappingContext arg_case4( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name, + "set_value"); + + TestArgumentMappingContext arg_case5( + {"Input", "StartsTensorList", "EndsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case6( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name, + "set_value"); + + TestArgumentMappingContext arg_case7( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name, + "set_value"); + + TestArgumentMappingContext arg_case8( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name, + "set_value"); + + TestArgumentMappingContext arg_case9( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name, + "set_value"); + + TestArgumentMappingContext arg_case10( + {"Input", "StartsTensorList", "StepsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case11( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name, + "set_value"); + + TestArgumentMappingContext arg_case12( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name, + "set_value"); + + TestArgumentMappingContext arg_case13( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name, + "set_value"); + + TestArgumentMappingContext arg_case14( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name, + "set_value"); + + TestArgumentMappingContext arg_case15( + {"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case16( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name, + "set_value"); + + TestArgumentMappingContext arg_case17( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name, + "set_value"); + + TestArgumentMappingContext arg_case18( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name, + "set_value"); + + TestArgumentMappingContext arg_case19( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name, + "set_value"); + + TestArgumentMappingContext arg_case20( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name, + "set_value"); + + TestArgumentMappingContext arg_case21( + {"Input", "EndsTensorList", "StepsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case22( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name, + "set_value"); + + TestArgumentMappingContext arg_case23( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name, + "set_value"); + + TestArgumentMappingContext arg_case24( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name, + "set_value"); + + TestArgumentMappingContext arg_case25( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name, + "set_value"); + + TestArgumentMappingContext arg_case26( + {"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case27( + {"Input", "EndsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name, + "set_value"); + + TestArgumentMappingContext arg_case28( + {"Input", "EndsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name, + "set_value"); + + TestArgumentMappingContext arg_case29( + {"Input", "EndsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name, + "set_value"); + + TestArgumentMappingContext arg_case30( + {"Input", "EndsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name, + "set_value"); + + TestArgumentMappingContext arg_case31( + {"Input", "EndsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name, + "set_value"); + + TestArgumentMappingContext arg_case32( + {"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case33( + {"Input", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name, + "set_value"); + + TestArgumentMappingContext arg_case34( + {"Input", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name, + "set_value"); + + TestArgumentMappingContext arg_case35( + {"Input", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name, + "set_value"); + + TestArgumentMappingContext arg_case36( + {"Input", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name, + "set_value"); + + TestArgumentMappingContext arg_case37( + {"Input", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name, + "set_value"); +} + } // namespace tests } // namespace phi