diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 3e9e0f89d43..da32b283c19 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -49,6 +49,7 @@ add_kernel(sin_compute_host Host extra SRCS sin_compute.cc DEPS ${lite_kernel_de add_kernel(cos_compute_host Host extra SRCS cos_compute.cc DEPS ${lite_kernel_deps}) add_kernel(crop_compute_host Host extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_host) add_kernel(crop_tensor_compute_host Host extra SRCS crop_tensor_compute.cc DEPS ${lite_kernel_deps} math_host) +add_kernel(sequence_pad_compute_host Host extra SRCS sequence_pad_compute.cc DEPS ${lite_kernel_deps} math_host) add_kernel(sequence_unpad_compute_host Host extra SRCS sequence_unpad_compute.cc DEPS ${lite_kernel_deps} math_host) add_kernel(sequence_expand_compute_host Host extra SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_softmax_compute_host Host extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps}) diff --git a/lite/kernels/host/sequence_pad_compute.cc b/lite/kernels/host/sequence_pad_compute.cc new file mode 100644 index 00000000000..527b46727cd --- /dev/null +++ b/lite/kernels/host/sequence_pad_compute.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/host/sequence_pad_compute.h" +#include "lite/backends/host/math/sequence_padding.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +template <typename T, PrecisionType PType> +void SequencePadCompute<T, PType>::Run() { + auto& param = this->template Param<param_t>(); + auto& ctx = this->ctx_->template As<HostContext>(); + + auto* x = param.X; + auto* pad_value = param.PadValue; + auto* len_t = param.Length; + auto* out = param.Out; + CHECK(!x->lod().empty()) << "Input X should have lod data."; + int padded_length = param.padded_length; + + lite::host::math::PaddingLoDTensorFunctor<lite::TargetType::kHost, T>()( + ctx, + *x, + out, + *pad_value, + padded_length, + 0, + false, + lite::host::math::kBatchLengthWidth); + + auto* len_data = len_t->template mutable_data<int64_t>(); + auto x_lod = x->lod(); + for (size_t i = 1; i < x_lod[0].size(); i++) { + len_data[i - 1] = x_lod[0][i] - x_lod[0][i - 1]; + } +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(sequence_pad, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::SequencePadCompute<float, PRECISION(kFloat)>, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))}) + .BindInput("PadValue", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))}) + .BindOutput("Length", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .Finalize(); + +REGISTER_LITE_KERNEL(sequence_pad, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::SequencePadCompute<int32_t, PRECISION(kFloat)>, + int32) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindInput("PadValue", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindOutput("Length", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .Finalize(); + +REGISTER_LITE_KERNEL(sequence_pad, + kHost, + kFloat, + kNCHW, + 
paddle::lite::kernels::host::SequencePadCompute<int64_t, PRECISION(kFloat)>, + int64) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .BindInput("PadValue", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .BindOutput("Length", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/host/sequence_pad_compute.h b/lite/kernels/host/sequence_pad_compute.h new file mode 100644 index 00000000000..e1bd5b07271 --- /dev/null +++ b/lite/kernels/host/sequence_pad_compute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +template <typename T, PrecisionType PType> +class SequencePadCompute : public KernelLite<TARGET(kHost), PType> { + public: + using param_t = operators::SequencePadParam; + + void Run() override; + + virtual ~SequencePadCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index f17f988cd20..1a5bb0559dd 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -64,6 +64,7 @@ if((NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LITE_WITH_MLU) AND (LITE_WIT if(LITE_BUILD_EXTRA) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS ${test_kernel_deps}) + lite_cc_test(test_sequence_pad SRCS sequence_pad_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_correlation SRCS correlation_test.cc DEPS ${test_kernel_deps}) #lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${test_kernel_deps}) diff --git a/lite/tests/kernels/sequence_pad_test.cc b/lite/tests/kernels/sequence_pad_test.cc new file mode 100644 index 00000000000..cd312e2187c --- /dev/null +++ b/lite/tests/kernels/sequence_pad_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include <gtest/gtest.h> +#include <string> +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +template <typename T> +class SequencePadTester : public arena::TestCase { + protected: + std::string x_ = "x"; + std::string pad_value_ = "pad_value"; + std::string out_ = "out"; + std::string length_ = "length"; + DDim x_dims_{{9, 2, 3, 4}}; + LoD x_lod_{{{0, 2, 5, 9}}}; + T value_ = 0; + int padded_length_ = 4; + + public: + SequencePadTester(const Place& place, const std::string& alias) + : TestCase(place, alias) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(out_); + auto out_shape = x_dims_.Vectorize(); + out_shape[0] = padded_length_; + out_shape.insert(out_shape.begin(), + static_cast<int64_t>(x_lod_[0].size() - 1)); + out->Resize(out_shape); + auto* out_data = out->template mutable_data<T>(); + for (int64_t i = 0; i < out->numel(); i++) { + out_data[i] = value_; + } + + int n = x_dims_.production() / x_dims_[0]; + int out_step = padded_length_ * n; + auto* x = scope->FindTensor(x_); + auto* x_data = x->template data<T>(); + for (size_t i = 1; i < x_lod_[0].size(); i++) { + int x_step = (x_lod_[0][i] - x_lod_[0][i - 1]) * n; + memcpy(out_data, x_data, sizeof(T) * x_step); + x_data += x_step; + out_data += out_step; + } + + auto* length = scope->NewTensor(length_); + length->Resize({static_cast<int64_t>(x_lod_[0].size() - 1)}); + int64_t* length_data = length->template mutable_data<int64_t>(); + for (size_t i = 1; i < x_lod_[0].size(); i++) { + length_data[i - 1] = x_lod_[0][i] - x_lod_[0][i - 1]; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("sequence_pad"); + op_desc->SetInput("X", {x_}); + op_desc->SetInput("PadValue", {pad_value_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetOutput("Length", {length_}); + 
op_desc->SetAttr("padded_length", padded_length_); + } + + void PrepareData() override { + std::vector<T> x_data(x_dims_.production()); + fill_data_rand(x_data.data(), -10, 10, x_dims_.production()); + SetCommonTensor(x_, x_dims_, x_data.data(), x_lod_); + + std::vector<T> pad_value_data{0}; + SetCommonTensor(pad_value_, DDim{{1}}, pad_value_data.data()); + } +}; + +template <typename T> +void TestSequencePad(const Place place, + const float abs_error, + const std::string alias) { + std::unique_ptr<arena::TestCase> tester( + new SequencePadTester<T>(place, alias)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); +} + +TEST(sequence_pad, precision) { + Place place; + float abs_error = 1e-5; +#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86) + place = TARGET(kHost); +#else + return; +#endif + + TestSequencePad<float>(place, abs_error, "def"); + TestSequencePad<int32_t>(place, abs_error, "int32"); + TestSequencePad<int64_t>(place, abs_error, "int64"); +} + +} // namespace lite +} // namespace paddle