Skip to content

Commit

Permalink
[Host] add sequence_pad (#5867) (#6088)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored May 17, 2021
1 parent aace354 commit a285276
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 0 deletions.
1 change: 1 addition & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_kernel(sin_compute_host Host extra SRCS sin_compute.cc DEPS ${lite_kernel_de
add_kernel(cos_compute_host Host extra SRCS cos_compute.cc DEPS ${lite_kernel_deps})
add_kernel(crop_compute_host Host extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_host)
add_kernel(crop_tensor_compute_host Host extra SRCS crop_tensor_compute.cc DEPS ${lite_kernel_deps} math_host)
add_kernel(sequence_pad_compute_host Host extra SRCS sequence_pad_compute.cc DEPS ${lite_kernel_deps} math_host)
add_kernel(sequence_unpad_compute_host Host extra SRCS sequence_unpad_compute.cc DEPS ${lite_kernel_deps} math_host)
add_kernel(flatten_contiguous_range_compute_host Host extra SRCS flatten_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shuffle_channel_compute_host Host extra SRCS shuffle_channel_compute.cc DEPS ${lite_kernel_deps} math_host)
Expand Down
100 changes: 100 additions & 0 deletions lite/kernels/host/sequence_pad_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/sequence_pad_compute.h"
#include "lite/backends/host/math/sequence_padding.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <class T>
void SequencePadCompute<T>::Run() {
  // Pads every sequence in LoDTensor X up to `padded_length` with the scalar
  // in PadValue, producing a dense Out of shape [num_seqs, padded_length, ...]
  // and writing each sequence's original length into the int64 Length tensor.
  auto& param = this->template Param<param_t>();
  auto& ctx = this->ctx_->template As<HostContext>();

  auto* x = param.X;
  auto* pad_value = param.PadValue;
  auto* len_t = param.Length;
  auto* out = param.Out;
  CHECK(!x->lod().empty()) << "Input X should have lod data.";
  int padded_length = param.padded_length;

  lite::host::math::PaddingLoDTensorFunctor<lite::TargetType::kHost, T>()(
      ctx,
      *x,
      out,
      *pad_value,
      padded_length,
      0,      // lod_level: pad according to the level-0 LoD
      false,  // norm_by_times: do not normalize values by sequence length
      lite::host::math::kBatchLengthWidth);

  // Length[i] = length of the i-th sequence, derived from adjacent LoD
  // offsets. Bind the LoD by const reference — x->lod() returns a nested
  // vector and copying it here would be pure waste.
  auto* len_data = len_t->template mutable_data<int64_t>();
  const auto& x_lod = x->lod();
  for (size_t i = 1; i < x_lod[0].size(); i++) {
    len_data[i - 1] = x_lod[0][i] - x_lod[0][i - 1];
  }
}

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle

// Kernel registrations. Three variants of sequence_pad are exposed, selected
// by alias ("def" / "int32" / "int64"). Note: the registration precision tag
// is kFloat for all three (Paddle-Lite convention — the tag identifies the
// kernel slot, not the payload); the actual element type is fixed by the
// template argument and the Bind* precisions below. The "Length" output is
// always int64, matching the LoD offset arithmetic in Run().

// float payload, alias "def".
REGISTER_LITE_KERNEL(sequence_pad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::SequencePadCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.BindInput("PadValue",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.BindOutput("Length",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();

// int32 payload, alias "int32".
REGISTER_LITE_KERNEL(sequence_pad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::SequencePadCompute<int>,
int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("PadValue",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Length",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();

// int64 payload, alias "int64".
REGISTER_LITE_KERNEL(sequence_pad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::SequencePadCompute<int64_t>,
int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindInput("PadValue",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindOutput("Length",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
37 changes: 37 additions & 0 deletions lite/kernels/host/sequence_pad_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

// Host kernel for the sequence_pad op: pads each sequence of a LoDTensor to
// a common length (see sequence_pad_compute.cc for the implementation).
// T is the element type of X/PadValue/Out (float, int32_t, or int64_t);
// the kFloat tag on KernelLite is the registration slot, not the data type.
template <class T>
class SequencePadCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
 public:
  using param_t = operators::SequencePadParam;

  void Run() override;

  virtual ~SequencePadCompute() = default;
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/tests/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ if((NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LITE_WITH_MLU) AND (LITE_WIT
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_scatter_nd_add SRCS scatter_nd_add_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_sequence_pad SRCS sequence_pad_test.cc DEPS ${test_kernel_deps})
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS ${test_kernel_deps})
Expand Down
116 changes: 116 additions & 0 deletions lite/tests/kernels/sequence_pad_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <cstring>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"

namespace paddle {
namespace lite {

// Arena test case for the sequence_pad host kernel. Feeds a LoDTensor with
// three sequences (lengths 2, 3, 4) and checks Out/Length against a
// straightforward reference padding implemented in RunBaseline.
template <class T>
class SequencePadTester : public arena::TestCase {
 protected:
  // Variable names bound in the scope / op desc.
  std::string x_ = "x";
  std::string pad_value_ = "pad_value";
  std::string out_ = "out";
  std::string length_ = "length";
  // 9 time steps of 2x3x4 features, split by the LoD into sequences
  // [0,2), [2,5), [5,9).
  DDim x_dims_{{9, 2, 3, 4}};
  LoD x_lod_{{{0, 2, 5, 9}}};
  T value_ = 0;            // pad value (also what Out is pre-filled with)
  int padded_length_ = 4;  // equals the longest sequence

 public:
  SequencePadTester(const Place& place, const std::string& alias)
      : TestCase(place, alias) {}

  // Reference implementation: fill Out with the pad value, then copy each
  // sequence's data to the start of its padded slot; Length records the
  // original sequence lengths.
  void RunBaseline(Scope* scope) override {
    auto* out = scope->NewTensor(out_);
    auto out_shape = x_dims_.Vectorize();
    out_shape[0] = padded_length_;
    out_shape.insert(out_shape.begin(),
                     static_cast<int64_t>(x_lod_[0].size() - 1));
    out->Resize(out_shape);
    auto* out_data = out->template mutable_data<T>();
    for (int64_t i = 0; i < out->numel(); i++) {
      out_data[i] = value_;
    }

    // n = elements per time step; out_step = elements per padded sequence.
    int n = x_dims_.production() / x_dims_[0];
    int out_step = padded_length_ * n;
    auto* x = scope->FindTensor(x_);
    auto* x_data = x->template data<T>();
    for (size_t i = 1; i < x_lod_[0].size(); i++) {
      int x_step = (x_lod_[0][i] - x_lod_[0][i - 1]) * n;
      memcpy(out_data, x_data, sizeof(T) * x_step);
      x_data += x_step;
      out_data += out_step;
    }

    auto* length = scope->NewTensor(length_);
    length->Resize({static_cast<int64_t>(x_lod_[0].size() - 1)});
    int64_t* length_data = length->template mutable_data<int64_t>();
    for (size_t i = 1; i < x_lod_[0].size(); i++) {
      length_data[i - 1] = x_lod_[0][i] - x_lod_[0][i - 1];
    }
  }

  // `override` added for consistency with RunBaseline/PrepareData and so the
  // compiler flags any future signature drift from arena::TestCase.
  void PrepareOpDesc(cpp::OpDesc* op_desc) override {
    op_desc->SetType("sequence_pad");
    op_desc->SetInput("X", {x_});
    op_desc->SetInput("PadValue", {pad_value_});
    op_desc->SetOutput("Out", {out_});
    op_desc->SetOutput("Length", {length_});
    op_desc->SetAttr("padded_length", padded_length_);
  }

  void PrepareData() override {
    std::vector<T> x_data(x_dims_.production());
    fill_data_rand<T>(x_data.data(), -10, 10, x_dims_.production());
    SetCommonTensor(x_, x_dims_, x_data.data(), x_lod_);

    std::vector<T> pad_value_data{0};
    SetCommonTensor(pad_value_, DDim{{1}}, pad_value_data.data());
  }
};

template <class T>
void TestSequencePad(const Place place,
const float abs_error,
const std::string alias) {
std::unique_ptr<arena::TestCase> tester(
new SequencePadTester<T>(place, alias));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}

// Runs the precision test for all three registered sequence_pad variants.
// The kernel is host-only, so the test is skipped unless the build targets
// ARM or X86 (where the kHost target is available).
TEST(sequence_pad, precision) {
  Place place;
  float abs_error = 1e-5;
#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
  place = TARGET(kHost);
#else
  return;  // no host kernel in this build configuration
#endif

  // Aliases match the REGISTER_LITE_KERNEL names in sequence_pad_compute.cc.
  TestSequencePad<float>(place, abs_error, "def");
  TestSequencePad<int>(place, abs_error, "int32");
  TestSequencePad<int64_t>(place, abs_error, "int64");
}

} // namespace lite
} // namespace paddle

0 comments on commit a285276

Please sign in to comment.