Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rois_num for roi_align xpu OP #28077

Merged
merged 9 commits into from
Oct 20, 2020
42 changes: 34 additions & 8 deletions paddle/fluid/operators/roi_align_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
int rois_num = rois->dims()[0];
const T* input_data = in->data<T>();
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs batch_size of roi_align_xpu OP must "
"be the same. But received rois_batch_size %d , batch_size %d",
rois_batch_size, batch_size));

framework::Tensor _roi_batch_list;
_roi_batch_list.Resize({rois_num});
int* rois_lod = _roi_batch_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size = 1;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
auto* rois_num_data = rois_num_t->data<int>();
rois_lod[0] = 0;
for (int n = 0; n < rois_batch_size; ++n) {
rois_lod[n + 1] = rois_lod[n] + rois_num_data[n];
}
} else {
auto _rois_lod = rois->lod().back();
rois_batch_size = _rois_lod.size() - 1;
for (int n = 0; n < _rois_lod.size(); ++n) {
rois_lod[n] = _rois_lod[n];
}
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs batch_size of roi_align_xpu OP "
"must "
"be the same. But received rois_batch_size %d , batch_size %d",
rois_batch_size, batch_size));
}
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num, rois_num_with_lod,
Expand Down
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,29 @@ def test_check_output(self):
self.check_output_with_place(place)


class TestROIAlignInLodOp(TestROIAlignOp):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_align()

seq_len = self.rois_lod[0]

self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': np.asarray(seq_len).astype('int32')
}

self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio
}

self.outputs = {'Out': self.out_data}


if __name__ == '__main__':
unittest.main()