Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor the interface for RoIAlignRotated #1662

Merged
merged 6 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ template <typename scalar_t>
__global__ void roi_align_rotated_forward_cuda_kernel(
const int nthreads, const scalar_t *bottom_data,
const scalar_t *bottom_rois, const scalar_t spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *top_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
Expand Down Expand Up @@ -58,11 +58,11 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
bottom_data + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0)
? sample_num
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
Expand Down Expand Up @@ -104,7 +104,7 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
template <typename scalar_t>
__global__ void roi_align_rotated_backward_cuda_kernel(
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
const scalar_t spatial_scale, const int sample_num, const bool aligned,
const scalar_t spatial_scale, const int sampling_ratio, const bool aligned,
const bool clockwise, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
Expand Down Expand Up @@ -146,11 +146,11 @@ __global__ void roi_align_rotated_backward_cuda_kernel(
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0)
? sample_num
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after.
Expand Down
28 changes: 14 additions & 14 deletions mmcv/ops/csrc/parrots/roi_align_rotated_parrots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -30,7 +30,7 @@ void roi_align_rotated_forward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cuda(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

Expand All @@ -41,14 +41,14 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -57,7 +57,7 @@ void roi_align_rotated_backward_cuda_parrots(CudaContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cuda(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}
#endif
Expand All @@ -69,14 +69,14 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -85,7 +85,7 @@ void roi_align_rotated_forward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
roi_align_rotated_forward_cpu(input, rois, output, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

Expand All @@ -96,14 +96,14 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
int pooled_height;
int pooled_width;
float spatial_scale;
int sample_num;
int sampling_ratio;
bool aligned;
bool clockwise;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.get<int>("sample_num", sample_num)
.get<int>("sampling_ratio", sampling_ratio)
.get<bool>("aligned", aligned)
.get<bool>("clockwise", clockwise)
.done();
Expand All @@ -112,15 +112,15 @@ void roi_align_rotated_backward_cpu_parrots(HostContext& ctx,
const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
roi_align_rotated_backward_cpu(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale, sample_num,
pooled_width, spatial_scale, sampling_ratio,
aligned, clockwise);
}

PARROTS_EXTENSION_REGISTER(roi_align_rotated_forward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("sample_num")
.attr("sampling_ratio")
.attr("aligned")
.attr("clockwise")
.input(2)
Expand All @@ -135,7 +135,7 @@ PARROTS_EXTENSION_REGISTER(roi_align_rotated_backward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.attr("sample_num")
.attr("sampling_ratio")
.attr("aligned")
.attr("clockwise")
.input(2)
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/parrots/roi_align_rotated_pytorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@
using namespace at;

#ifdef MMCV_WITH_CUDA
void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_cuda(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
#endif

void roi_align_rotated_forward_cpu(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cpu(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_cpu(Tensor grad_output, Tensor rois,
Tensor bottom_grad, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);

#endif // ROI_ALIGN_ROTATED_PYTORCH_H
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,15 @@ void roi_align_rotated_backward_cpu(Tensor top_grad, Tensor rois,
sampling_ratio, aligned, clockwise);
}

void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CPU,
roi_align_rotated_forward_cpu);
Expand Down
28 changes: 14 additions & 14 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,20 +924,20 @@ REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda);

void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output);

void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad);

void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_cuda(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
Expand All @@ -947,19 +947,19 @@ void roi_align_rotated_forward_cuda(Tensor features, Tensor rois, Tensor output,
AT_ERROR("wrong roi size");
}

int num_channels = features.size(1);
int data_height = features.size(2);
int data_width = features.size(3);
int num_channels = input.size(1);
int data_height = input.size(2);
int data_width = input.size(3);
ROIAlignRotatedForwardCUDAKernelLauncher(
features, rois, spatial_scale, sample_ratio, aligned, clockwise,
input, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, output);
}

void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise) {
// Number of ROIs
int num_rois = rois.size(0);
Expand All @@ -972,20 +972,20 @@ void roi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
int data_height = bottom_grad.size(2);
int data_width = bottom_grad.size(3);
ROIAlignRotatedBackwardCUDAKernelLauncher(
top_grad, rois, spatial_scale, sample_ratio, aligned, clockwise,
top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise,
num_channels, data_height, data_width, num_rois, aligned_height,
aligned_width, bottom_grad);
}

void roi_align_rotated_forward_impl(Tensor features, Tensor rois, Tensor output,
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
int aligned_height, int aligned_width,
float spatial_scale, int sample_ratio,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
Tensor bottom_grad, int aligned_height,
int aligned_width, float spatial_scale,
int sample_ratio, bool aligned,
int sampling_ratio, bool aligned,
bool clockwise);
REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
roi_align_rotated_forward_cuda);
Expand Down
14 changes: 7 additions & 7 deletions mmcv/ops/csrc/pytorch/cuda/roi_align_rotated_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
#include "roi_align_rotated_cuda_kernel.cuh"

void ROIAlignRotatedForwardCUDAKernelLauncher(
const at::Tensor features, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const at::Tensor input, const at::Tensor rois, const float spatial_scale,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor output) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = features.data_ptr<scalar_t>();
input.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] {
const scalar_t *bottom_data = input.data_ptr<scalar_t>();
const scalar_t *rois_data = rois.data_ptr<scalar_t>();
scalar_t *top_data = output.data_ptr<scalar_t>();

roi_align_rotated_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sample_num, aligned, clockwise, channels, height, width,
sampling_ratio, aligned, clockwise, channels, height, width,
pooled_height, pooled_width, top_data);
}));

Expand All @@ -26,7 +26,7 @@ void ROIAlignRotatedForwardCUDAKernelLauncher(

void ROIAlignRotatedBackwardCUDAKernelLauncher(
const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
const int sample_num, const bool aligned, const bool clockwise,
const int sampling_ratio, const bool aligned, const bool clockwise,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_height * pooled_width * channels;
Expand All @@ -37,7 +37,7 @@ void ROIAlignRotatedBackwardCUDAKernelLauncher(
scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
roi_align_rotated_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
output_size, top_diff, rois_data, spatial_scale, sample_num,
output_size, top_diff, rois_data, spatial_scale, sampling_ratio,
aligned, clockwise, channels, height, width, pooled_height,
pooled_width, bottom_diff);
}));
Expand Down
9 changes: 5 additions & 4 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,14 @@ Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,

void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale, int sample_num,
float spatial_scale, int sampling_ratio,
bool aligned, bool clockwise);

void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise);
int sampling_ratio, bool aligned,
bool clockwise);

std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const torch::Tensor &feats, const torch::Tensor &coors,
Expand Down Expand Up @@ -649,13 +650,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"),
py::arg("spatial_scale"), py::arg("sampling_ratio"), py::arg("aligned"),
py::arg("clockwise"));
m.def("roi_align_rotated_backward", &roi_align_rotated_backward,
"roi_align_rotated backward", py::arg("rois"), py::arg("grad_input"),
py::arg("grad_output"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise"));
m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward,
"dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"),
py::arg("reduce_type"));
Expand Down
Loading