Skip to content

Commit

Permalink
[XPU] support slice int32; fix slice, expand (PaddlePaddle#5488)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Feb 25, 2021
1 parent 06b8759 commit 9b09a64
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 50 deletions.
35 changes: 29 additions & 6 deletions lite/kernels/host/expand_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ REGISTER_LITE_KERNEL(expand, kHost, kFloat, kAny, expand_float, def)
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindInput("expand_times_tensor",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
Expand All @@ -111,11 +111,34 @@ REGISTER_LITE_KERNEL(expand, kHost, kInt32, kAny, expand_int32, def)
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindInput("expand_times_tensor",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.Finalize();

#ifdef LITE_BUILD_EXTRA
// Extra-build-only variant of the host `expand` kernel for int32 input
// tensors. It is registered under the kFloat precision slot (same op name
// "expand") and distinguished from the float kernel by the "int32" alias,
// so the type system picks it when X carries kInt32 data.
using expand_int32_f =
paddle::lite::kernels::host::ExpandCompute<int, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(expand, kHost, kFloat, kAny, expand_int32_f, int32)
// X: the tensor to be tiled; int32 elements, any layout.
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
// ExpandTimes: optional single tensor holding per-axis repeat counts.
.BindInput("ExpandTimes",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
// expand_times_tensor: optional per-axis scalar tensors with repeat counts.
.BindInput("expand_times_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.Finalize();
#endif  // LITE_BUILD_EXTRA
13 changes: 11 additions & 2 deletions lite/kernels/xpu/fill_constant_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ void FillConstantCompute::Run() {
static_cast<int32_t>(param.value));
break;
}
case 3: {
auto data = param.out->mutable_data<int64_t>(TARGET(kXPU));
r = xdnn::constant<int64_t>(ctx.GetRawContext(),
data,
write_size,
static_cast<int64_t>(param.value));
break;
}
case 5: {
auto data = param.out->mutable_data<float>(TARGET(kXPU));
r = xdnn::constant<float>(ctx.GetRawContext(),
Expand All @@ -54,7 +62,8 @@ void FillConstantCompute::Run() {
}
default: {
LOG(FATAL) << "Attribute dtype in fill_constant op "
"must be 1[int16] or 2[int32] or 5[fp32] for xpu: "
"must be 1[int16] or 3[int64] or 2[int32] or 5[fp32] "
"for xpu: "
<< param.dtype;
break;
}
Expand All @@ -78,5 +87,5 @@ REGISTER_LITE_KERNEL(fill_constant,
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.BindPaddleOpVersion("fill_constant", 1)
.BindPaddleOpVersion("fill_constant", 2)
.Finalize();
76 changes: 45 additions & 31 deletions lite/kernels/xpu/slice_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,35 @@ namespace lite {
namespace kernels {
namespace xpu {

void SliceCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto x_dims = param.X->dims();
x_shape_.reserve(x_dims.size());
x_dim_begin_.reserve(x_dims.size());
x_dim_end_.reserve(x_dims.size());
}

void SliceCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
template <class T>
void SliceCompute<T>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

auto x_dims = param.X->dims();
for (size_t i = 0; i < x_dims.size(); ++i) {
x_shape_[i] = x_dims[i];
x_dim_begin_[i] = 0;
x_dim_end_[i] = x_dims[i];
}
auto x_shape = x_dims.Vectorize();
std::vector<int> x_shape_(x_shape.begin(), x_shape.end());
std::vector<int> x_dim_begin_(x_dims.size(), 0);
std::vector<int> x_dim_end_(x_shape_);

for (size_t i = 0; i < param.axes.size(); ++i) {
int axis = param.axes[i];
x_dim_begin_[axis] = param.starts[i];
x_dim_end_[axis] = param.ends[i];
x_dim_begin_[axis] = param.starts[i] < 0
? param.starts[i] + static_cast<int>(x_dims[axis])
: param.starts[i];
int end = param.ends[i] < 0 ? param.ends[i] + static_cast<int>(x_dims[axis])
: param.ends[i];
x_dim_end_[axis] = (std::min)(end, static_cast<int>(x_dims[axis]));
}

int ndim = param.X->dims().size();
int r = xdnn::slice_forward(
ctx.GetRawContext(), /* context */
&x_shape_[0], /* shape */
&x_dim_begin_[0], /* starts */
&x_dim_end_[0], /* ends */
ndim, /* n */
param.X->data<float>(), /* in */
param.Out->mutable_data<float>(TARGET(kXPU)) /* out */);
int r =
xdnn::slice(ctx.GetRawContext(), /* context */
param.X->template data<T>(), /* in */
param.Out->template mutable_data<T>(TARGET(kXPU)), /* out */
x_shape_,
x_dim_begin_,
x_dim_end_);

CHECK_EQ(r, 0);
}

Expand All @@ -62,8 +58,26 @@ void SliceCompute::Run() {
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(
slice, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::SliceCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
// XPU slice kernel over float32 tensors; "float32" alias, any data layout.
using SliceFloat32 = paddle::lite::kernels::xpu::SliceCompute<float>;
REGISTER_LITE_KERNEL(slice, kXPU, kFloat, kAny, SliceFloat32, float32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kXPU),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kXPU),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.Finalize();

// XPU slice kernel over int32 tensors. Registered under the kFloat
// precision slot (same op, same kernel class template) and selected via
// the "int32" alias when Input carries kInt32 data.
using SliceInt32 = paddle::lite::kernels::xpu::SliceCompute<int32_t>;
REGISTER_LITE_KERNEL(slice, kXPU, kFloat, kAny, SliceInt32, int32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kXPU),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kXPU),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.Finalize();
11 changes: 3 additions & 8 deletions lite/kernels/xpu/slice_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,15 @@ namespace lite {
namespace kernels {
namespace xpu {

// XPU kernel computing the slice op for element type T (float / int32_t).
// Declared with PRECISION(kFloat) + DATALAYOUT(kAny) so both element-type
// instantiations can register under the same precision slot.
//
// NOTE: the refactor that templated this class moved all per-run shape
// bookkeeping into Run(); the former PrepareForRun() and the cached
// x_shape_/x_dim_begin_/x_dim_end_ members are intentionally gone — the
// kernel is now stateless between runs.
template <class T>
class SliceCompute
    : public KernelLite<TARGET(kXPU), PRECISION(kFloat), DATALAYOUT(kAny)> {
 public:
  using param_t = operators::SliceParam;

  // Performs the slice on the XPU device; see slice_compute.cc.
  virtual void Run();

  virtual ~SliceCompute() = default;
};

} // namespace xpu
Expand Down
21 changes: 18 additions & 3 deletions lite/operators/expand_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ bool ExpandOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Out);

int x_dims_size = param_.X->dims().size();
CHECK_LE(x_dims_size, 6u)
CHECK_LE(x_dims_size, 6)
<< "The rank of Input(X) must not be greater than 6.";

int expand_size = 0;
Expand All @@ -43,9 +43,24 @@ bool ExpandOpLite::CheckShape() const {
}

bool ExpandOpLite::InferShapeImpl() const {
std::vector<int> expand_times;
if (param_.ExpandTimes != nullptr) {
auto expand_times_data = param_.ExpandTimes->template data<int>();
for (int64_t i = 0; i < param_.ExpandTimes->numel(); i++) {
expand_times.push_back(expand_times_data[i]);
}
} else if (!param_.expand_times_tensor.empty()) {
for (size_t i = 0; i < param_.expand_times_tensor.size(); i++) {
expand_times.push_back(
param_.expand_times_tensor[i]->template data<int>()[0]);
}
} else {
expand_times = param_.expand_times;
}

DDim out_dims(param_.X->dims());
for (size_t i = 0; i < param_.expand_times.size(); ++i) {
out_dims[i] *= param_.expand_times[i];
for (size_t i = 0; i < expand_times.size(); ++i) {
out_dims[i] *= static_cast<int64_t>(expand_times[i]);
}
param_.Out->Resize(out_dims);
return true;
Expand Down

0 comments on commit 9b09a64

Please sign in to comment.