Skip to content

Commit

Permalink
[XPU] support slice int32; fix slice, expand; test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Feb 8, 2021
1 parent bd7af40 commit 395104e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
70 changes: 44 additions & 26 deletions lite/kernels/xpu/slice_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,34 @@ namespace lite {
namespace kernels {
namespace xpu {

void SliceCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
template <class T>
void SliceCompute<T>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

auto x_dims = param.X->dims();
std::vector<int> x_shape_;
std::vector<int> x_dim_begin_;
std::vector<int> x_dim_end_;
x_shape_.resize(x_dims.size());
x_dim_begin_.resize(x_dims.size());
x_dim_end_.resize(x_dims.size());
auto x_shape = x_dims.Vectorize();
std::vector<int> x_shape_(x_shape.begin(), x_shape.end());
std::vector<int> x_dim_begin_(x_dims.size(), 0);
std::vector<int> x_dim_end_(x_shape_);

for (size_t i = 0; i < x_dims.size(); ++i) {
x_shape_[i] = static_cast<int>(x_dims[i]);
x_dim_begin_[i] = 0;
x_dim_end_[i] = x_shape_[i];
}
for (size_t i = 0; i < param.axes.size(); ++i) {
int axis = param.axes[i];
x_dim_begin_[axis] = param.starts[i];
x_dim_end_[axis] = param.ends[i];
x_dim_begin_[axis] = param.starts[i] < 0
? param.starts[i] + static_cast<int>(x_dims[axis])
: param.starts[i];
int end = param.ends[i] < 0 ? param.ends[i] + static_cast<int>(x_dims[axis])
: param.ends[i];
x_dim_end_[axis] = (std::min)(end, static_cast<int>(x_dims[axis]));
}
int r = xdnn::slice(ctx.GetRawContext(), /* context */
param.X->data<float>(), /* in */
param.Out->mutable_data<float>(TARGET(kXPU)), /* out */
x_shape_,
x_dim_begin_,
x_dim_end_);

int r =
xdnn::slice(ctx.GetRawContext(), /* context */
param.X->template data<T>(), /* in */
param.Out->template mutable_data<T>(TARGET(kXPU)), /* out */
x_shape_,
x_dim_begin_,
x_dim_end_);

CHECK_EQ(r, 0);
}
Expand All @@ -58,8 +58,26 @@ void SliceCompute::Run() {
} // namespace lite
} // namespace paddle

// Register the templated slice kernel once per supported element type.
// Both registrations use layout kAny so the kernel matches any data layout.
using SliceFloat32 = paddle::lite::kernels::xpu::SliceCompute<float>;
REGISTER_LITE_KERNEL(slice, kXPU, kFloat, kAny, SliceFloat32, float32)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kXPU),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kAny))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kXPU),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kAny))})
    .Finalize();

using SliceInt32 = paddle::lite::kernels::xpu::SliceCompute<int32_t>;
REGISTER_LITE_KERNEL(slice, kXPU, kFloat, kAny, SliceInt32, int32)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kXPU),
                                      PRECISION(kInt32),
                                      DATALAYOUT(kAny))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kXPU),
                                       PRECISION(kInt32),
                                       DATALAYOUT(kAny))})
    .Finalize();
4 changes: 3 additions & 1 deletion lite/kernels/xpu/slice_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ namespace lite {
namespace kernels {
namespace xpu {

class SliceCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
template <class T>
class SliceCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat), DATALAYOUT(kAny)> {
public:
using param_t = operators::SliceParam;

Expand Down
21 changes: 18 additions & 3 deletions lite/operators/expand_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ bool ExpandOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Out);

int x_dims_size = param_.X->dims().size();
CHECK_LE(x_dims_size, 6u)
CHECK_LE(x_dims_size, 6)
<< "The rank of Input(X) must not be greater than 6.";

int expand_size = 0;
Expand All @@ -43,9 +43,24 @@ bool ExpandOpLite::CheckShape() const {
}

bool ExpandOpLite::InferShapeImpl() const {
  // Resolve the per-axis repeat factors with the usual input precedence:
  // 1) a single ExpandTimes tensor, 2) a list of scalar tensors,
  // 3) the plain attribute vector.
  std::vector<int> expand_times;
  if (param_.ExpandTimes != nullptr) {
    auto expand_times_data = param_.ExpandTimes->template data<int>();
    for (int64_t i = 0; i < param_.ExpandTimes->numel(); i++) {
      expand_times.push_back(expand_times_data[i]);
    }
  } else if (!param_.expand_times_tensor.empty()) {
    // Each tensor in the list holds one scalar repeat factor.
    for (size_t i = 0; i < param_.expand_times_tensor.size(); i++) {
      expand_times.push_back(
          param_.expand_times_tensor[i]->template data<int>()[0]);
    }
  } else {
    expand_times = param_.expand_times;
  }

  // Output dim = input dim * repeat factor, axis by axis.
  DDim out_dims(param_.X->dims());
  for (size_t i = 0; i < expand_times.size(); ++i) {
    out_dims[i] *= static_cast<int64_t>(expand_times[i]);
  }
  param_.Out->Resize(out_dims);
  return true;
}
Expand Down
1 change: 0 additions & 1 deletion lite/tests/kernels/elementwise_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ T div(T a, T b) {

// Reference floordiv for the elementwise tests: divides a by b and rounds
// the quotient toward zero via std::trunc, matching the kernel's semantics.
// (Leftover debug LOG(INFO) tracing removed.)
template <class T>
T floordiv(T a, T b) {
  return static_cast<T>(std::trunc(a / b));
}

Expand Down

0 comments on commit 395104e

Please sign in to comment.