Skip to content

Commit

Permalink
[Refactor] Refactor the interface for RoIAlignRotated (#1662)
Browse files Browse the repository at this point in the history
* Fix the interface for RoIAlignRotated

* Add a unit test for RoIAlignRotated

* Make a unit test for RoIAlignRotated concise

* Fix the interface for RoIAlignRotated

* Refactor ext_module.nms_rotated

* Lint cpp files
  • Loading branch information
nijkah authored Feb 18, 2022
1 parent fccb109 commit b83bdb0
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 118 deletions.
16 changes: 8 additions & 8 deletions mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ template <typename scalar_t>
__global__ void roi_align_rotated_forward_cuda_kernel(
const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_rois, const scalar_t spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *top_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
Expand Down Expand Up @@ -58,11 +58,11 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
bottom_data + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0)
? sample_num
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
Expand Down Expand Up @@ -104,7 +104,7 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
template <typename scalar_t>
__global__ void roi_align_rotated_backward_cuda_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
const scalar_t spatial_scale, const int sample_num, const bool aligned,
const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
const bool clockwise, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
Expand Down Expand Up @@ -146,11 +146,11 @@ __global__ void roi_align_rotated_backward_cuda_kernel(
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0)
? sample_num
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
Expand Down
28 changes: 14 additions & 14 deletions mmcv/ops/csrc/parrots/roi_align_rotated_parrots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -30,7 +30,7 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cuda(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

Expand All @@ -41,14 +41,14 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -57,7 +57,7 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cuda(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}
#endif
Expand All @@ -69,14 +69,14 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -85,7 +85,7 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cpu(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

Expand All @@ -96,14 +96,14 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -112,15 +112,15 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cpu(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

PARROTS_EXTENSION_REGISTER(roi_align_rotated_forward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("sample_num")
.attr("sampling_ratio")
.attr("aligned")
.attr("clockwise")
.input(2)
Expand All @@ -135,7 +135,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_backward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("sample_num")
.attr("sampling_ratio")
.attr("aligned")
.attr("clockwise")
.input(2)
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/parrots/roi_align_rotated_pytorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@
using namespace at;

#ifdef MMCV_WITH_CUDA
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_cuda(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
#endif

void roi_align_rotated_forward_cpu(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cpu(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_cpu(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);

#endif // ROI_ALIGN_ROTATED_PYTORCH_H
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,15 @@ void roi_align_rotated_backward_cpu(Tensor top_grad, Tensor rois,
sampling_ratio, aligned, clockwise);
}

void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CPU,
roi_align_rotated_forward_cpu);
Expand Down
28 changes: 14 additions & 14 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda);

void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output);

void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad);

void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
Expand All @@ -947,19 +947,19 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
AT_ERROR("wrong roi size");
}

int num_channels = features.size(1);
int data_height = features.size(2);
int data_width = features.size(3);
int num_channels = input.size(1);
int data_height = input.size(2);
int data_width = input.size(3);
ROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, sample_ratio, aligned, clockwise,
input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output);
}

void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
Expand All @@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, sample_ratio, aligned, clockwise,
top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad);
}

void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
roi_align_rotated_forward_cuda);
Expand Down
14 changes: 7 additions & 7 deletions mmcv/ops/csrc/pytorch/cuda/roi_align_rotated_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
#include "roi_align_rotated_cuda_kernel.cuh"

void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
input.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = input.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();

roi_align_rotated_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sample_num, aligned, clockwise, channels, height, width,
sampling_ratio, aligned, clockwise, channels, height, width,
pooled_height, pooled_width, top_data);
}));

Expand All @@ -26,7 +26,7 @@ void ROIAlignRotatedForwardCUDAKernelLauncher(

void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
Expand All @@ -37,7 +37,7 @@ void ROIAlignRotatedBackwardCUDAKernelLauncher(
scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
roi_align_rotated_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, top_diff, rois_data, spatial_scale, sample_num,
output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
aligned, clockwise, channels, height, width, pooled_height,
pooled_width, bottom_diff);
}));
Expand Down
9 changes: 5 additions & 4 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,14 @@ Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,

void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise);
int sampling_ratio, bool aligned,
bool clockwise);

std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const torch::Tensor &feats, const torch::Tensor &coors,
Expand Down Expand Up @@ -736,13 +737,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"),
py::arg("spatial_scale"), py::arg("sampling_ratio"), py::arg("aligned"),
py::arg("clockwise"));
m.def("roi_align_rotated_backward", &roi_align_rotated_backward,
"roi_align_rotated backward", py::arg("rois"), py::arg("grad_input"),
py::arg("grad_output"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise"));
m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward,
"dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"),
py::arg("reduce_type"));
Expand Down
Loading

0 comments on commit b83bdb0

Please sign in to comment.