From 2f9eb12f701a005803a2443b00a69188e765bb8e Mon Sep 17 00:00:00 2001 From: dingchang Date: Thu, 21 Oct 2021 20:00:04 +0800 Subject: [PATCH 01/30] [Feature] Add roiaware pool3d ops from mmdet3d (#1382) * add ops (roiaware pool3d) in mmdet3d * refactor code * fix typo Co-authored-by: zhouzaida --- docs/understand_mmcv/ops.md | 1 + docs_zh_CN/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 101 +++++-- .../cuda/points_in_boxes_cuda_kernel.cuh | 93 ++++++ .../cuda/roiaware_pool3d_cuda_kernel.cuh | 268 ++++++++++++++++++ .../csrc/pytorch/cuda/points_in_boxes_cuda.cu | 62 ++++ .../csrc/pytorch/cuda/roiaware_pool3d_cuda.cu | 118 ++++++++ mmcv/ops/csrc/pytorch/points_in_boxes.cpp | 92 ++++++ mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp | 53 ++++ mmcv/ops/csrc/pytorch/pybind.cpp | 33 +++ mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp | 115 ++++++++ mmcv/ops/points_in_boxes.py | 133 +++++++++ mmcv/ops/roiaware_pool3d.py | 114 ++++++++ tests/test_ops/test_roiaware_pool3d.py | 135 +++++++++ 14 files changed, 1299 insertions(+), 20 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/points_in_boxes.cpp create mode 100644 mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp create mode 100644 mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp create mode 100644 mmcv/ops/points_in_boxes.py create mode 100644 mmcv/ops/roiaware_pool3d.py create mode 100644 tests/test_ops/test_roiaware_pool3d.py diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index 28df9d6778..900705afa0 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -23,6 +23,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - RoIPointPool3d - RoIPool - RoIAlign +- RoIAwarePool3d - SimpleRoIAlign - SigmoidFocalLoss - SoftmaxFocalLoss diff --git a/docs_zh_CN/understand_mmcv/ops.md b/docs_zh_CN/understand_mmcv/ops.md index 425ae06a63..a45bb14862 100644 --- a/docs_zh_CN/understand_mmcv/ops.md +++ b/docs_zh_CN/understand_mmcv/ops.md @@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - RoIPointPool3d - RoIPool - RoIAlign +- RoIAwarePool3d - SimpleRoIAlign - SigmoidFocalLoss - SoftmaxFocalLoss diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 9eed47e5c3..338e32b652 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -34,11 +34,14 @@ from .pixel_group import pixel_group from .point_sample import (SimpleRoIAlign, point_sample, rel_roi_point_to_rel_img_point) +from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu, + points_in_boxes_part) from .points_sampler import PointsSampler from .psa_mask import PSAMask from .roi_align import RoIAlign, roi_align from .roi_align_rotated import RoIAlignRotated, roi_align_rotated from .roi_pool import RoIPool, roi_pool +from .roiaware_pool3d import RoIAwarePool3d from .roipoint_pool3d import RoIPointPool3d from .saconv import SAConv2d from .scatter_points import DynamicScatter, dynamic_scatter @@ -50,24 +53,82 @@ from .voxelize import Voxelization, voxelization __all__ = [ - 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', - 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', - 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', - 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', - 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', - 'get_compiler_version', 'get_compiling_cuda_version', - 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', - 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', - 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', - 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', - 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', - 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', - 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', - 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', - 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', - 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', - 'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention', - 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', - 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', - 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation' + 'bbox_overlaps', + 'CARAFE', + 'CARAFENaive', + 'CARAFEPack', + 'carafe', + 'carafe_naive', + 'CornerPool', + 'DeformConv2d', + 'DeformConv2dPack', + 'deform_conv2d', + 'DeformRoIPool', + 'DeformRoIPoolPack', + 'ModulatedDeformRoIPoolPack', + 'deform_roi_pool', + 'SigmoidFocalLoss', + 'SoftmaxFocalLoss', + 'sigmoid_focal_loss', + 'softmax_focal_loss', + 'get_compiler_version', + 'get_compiling_cuda_version', + 'get_onnxruntime_op_path', + 'MaskedConv2d', + 'masked_conv2d', + 'ModulatedDeformConv2d', + 'ModulatedDeformConv2dPack', + 'modulated_deform_conv2d', + 'batched_nms', + 'nms', + 'soft_nms', + 'nms_match', + 'RoIAlign', + 'roi_align', + 'RoIPool', + 'roi_pool', + 'SyncBatchNorm', + 'Conv2d', + 'ConvTranspose2d', + 'Linear', + 'MaxPool2d', + 'CrissCrossAttention', + 'PSAMask', + 'point_sample', + 'rel_roi_point_to_rel_img_point', + 'SimpleRoIAlign', + 'SAConv2d', + 'TINShift', + 'tin_shift', + 'assign_score_withk', + 'box_iou_rotated', + 'RoIPointPool3d', + 'nms_rotated', + 'knn', + 'ball_query', + 'upfirdn2d', + 'FusedBiasLeakyReLU', + 'fused_bias_leakyrelu', + 'RoIAlignRotated', + 'roi_align_rotated', + 'pixel_group', + 'contour_expand', + 'three_nn', + 'three_interpolate', + 'MultiScaleDeformableAttention', + 'Voxelization', + 'voxelization', + 'dynamic_scatter', + 'DynamicScatter', + 'BorderAlign', + 'border_align', + 'gather_points', + 'furthest_point_sample', + 'furthest_point_sample_with_dist', + 'PointsSampler', + 'Correlation', + 'RoIAwarePool3d', + 'points_in_boxes_part', + 'points_in_boxes_cpu', + 'points_in_boxes_all', ] diff --git a/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh new file mode 100644 index 0000000000..12182cc370 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh @@ -0,0 +1,93 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef POINT_IN_BOXES_CUDA_KERNEL_CUH +#define POINT_IN_BOXES_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +template +__global__ void points_in_boxes_part_forward_cuda_kernel( + int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts, + int *box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points: + // (B, npoints), default -1 + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= batch_size || pt_idx >= pts_num) return; + + boxes += bs_idx * boxes_num * 7; + pts += bs_idx * pts_num * 3 + pt_idx * 3; + box_idx_of_points += bs_idx * pts_num + pt_idx; + + T local_x = 0, local_y = 0; + int cur_in_flag = 0; + for (int k = 0; k < boxes_num; k++) { + cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y); + if (cur_in_flag) { + box_idx_of_points[0] = k; + break; + } + } +} + +template +__global__ void points_in_boxes_all_forward_cuda_kernel( + int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts, + int *box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points: + // (B, npoints), default -1 + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= batch_size || pt_idx >= pts_num) return; + + boxes += bs_idx * boxes_num * 7; + pts += bs_idx * pts_num * 3 + pt_idx * 3; + box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num; + + T local_x = 0, local_y = 0; + for (int k = 0; k < boxes_num; k++) { + const int cur_in_flag = + check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y); + if (cur_in_flag) { + box_idx_of_points[k] = 1; + } + } +} + +#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh new file mode 100644 index 0000000000..3b95dc7908 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh @@ -0,0 +1,268 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROIAWARE_POOL3D_CUDA_KERNEL_CUH +#define ROIAWARE_POOL3D_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +template +__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, + int out_x, int out_y, int out_z, + const T *rois, const T *pts, + int *pts_mask) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N, + // npoints): -1 means point does not in this box, otherwise: encode (x_idxs, + // y_idxs, z_idxs) by binary bit + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + if (pt_idx >= pts_num || box_idx >= boxes_num) return; + + pts += pt_idx * 3; + rois += box_idx * 7; + pts_mask += box_idx * pts_num + pt_idx; + + T local_x = 0, local_y = 0; + int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y); + + pts_mask[0] = -1; + if (cur_in_flag > 0) { + T local_z = pts[2] - rois[2]; + T x_size = rois[3], y_size = rois[4], z_size = rois[5]; + + T x_res = x_size / out_x; + T y_res = y_size / out_y; + T z_res = z_size / out_z; + + unsigned int x_idx = int((local_x + x_size / 2) / x_res); + unsigned int y_idx = int((local_y + y_size / 2) / y_res); + unsigned int z_idx = int(local_z / z_res); + + x_idx = min(max(x_idx, 0), out_x - 1); + y_idx = min(max(y_idx, 0), out_y - 1); + z_idx = min(max(z_idx, 0), out_z - 1); + + unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx; + + pts_mask[0] = idx_encoding; + } +} + +template +__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, + int max_pts_each_voxel, int out_x, + int out_y, int out_z, + const int *pts_mask, + T *pts_idx_of_voxels) { + // params pts_mask: (N, npoints) 0 or 1 + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + + int box_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (box_idx >= boxes_num) return; + + int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel; + + for (int k = 0; k < pts_num; k++) { + if (pts_mask[box_idx * pts_num + k] != -1) { + unsigned int idx_encoding = pts_mask[box_idx * pts_num + k]; + unsigned int x_idx = (idx_encoding >> 16) & 0xFF; + unsigned int y_idx = (idx_encoding >> 8) & 0xFF; + unsigned int z_idx = idx_encoding & 0xFF; + unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel + + y_idx * out_z * max_pts_each_voxel + + z_idx * max_pts_each_voxel; + unsigned int cnt = pts_idx_of_voxels[base_offset]; + if (cnt < max_num_pts) { + pts_idx_of_voxels[base_offset + cnt + 1] = k; + pts_idx_of_voxels[base_offset]++; + } + } + } +} + +template +__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features, int *argmax) { + // params pts_feature: (npoints, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int argmax_idx = -1; + float max_val = -1e50; + + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) { + max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + argmax_idx = pts_idx_of_voxels[k]; + } + } + + if (argmax_idx != -1) { + pooled_features[0] = max_val; + } + argmax[0] = argmax_idx; +} + +template +__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features) { + // params pts_feature: (npoints, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + float sum_val = 0; + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + } + + if (total_pts > 0) { + pooled_features[0] = sum_val / total_pts; + } +} + +template +__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + const int *argmax, + const T *grad_out, T *grad_in) { + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + if (argmax[0] == -1) return; + + atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1); +} + +template +__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + int max_pts_each_voxel, + const int *pts_idx_of_voxels, + const T *grad_out, T *grad_in) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int total_pts = pts_idx_of_voxels[0]; + float cur_grad = 1 / fmaxf(float(total_pts), 1.0); + for (int k = 1; k <= total_pts; k++) { + atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx, + grad_out[0] * cur_grad); + } +} + +#endif // ROIAWARE_POOL3D_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu new file mode 100644 index 0000000000..17e6441ba4 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu @@ -0,0 +1,62 @@ +// Modified from +// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu +// Written by Shaoshuai Shi +// All Rights Reserved 2019. + +#include + +#include "points_in_boxes_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is + // the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x, + // y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default + // -1 + + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + boxes.scalar_type(), "points_in_boxes_part_forward_cuda_kernel", [&] { + points_in_boxes_part_forward_cuda_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box params pts: (B, npoints, 3) + // [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), + // default -1 + + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + boxes.scalar_type(), "points_in_boxes_all_forward_cuda_kernel", [&] { + points_in_boxes_all_forward_cuda_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu new file mode 100644 index 0000000000..2bc7c3f764 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu @@ -0,0 +1,118 @@ +// Modified from +// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu +// Written by Shaoshuai Shi +// All Rights Reserved 2019. + +#include + +#include "pytorch_cuda_helper.hpp" +#include "roiaware_pool3d_cuda_kernel.cuh" + +void RoiawarePool3dForwardCUDAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] in LiDAR coordinate params + // pts_feature: (npoints, C) params argmax: (N, out_x, out_y, out_z, C) params + // pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) params + // pooled_features: (N, out_x, out_y, out_z, C) params pool_method: 0: + // max_pool 1: avg_pool + + at::cuda::CUDAGuard device_guard(pts_feature.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + Tensor pts_mask = + -at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt)); + + dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rois.scalar_type(), "generate_pts_mask_for_box3d", [&] { + generate_pts_mask_for_box3d + <<>>( + boxes_num, pts_num, out_x, out_y, out_z, + rois.data_ptr(), pts.data_ptr(), + pts_mask.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); + + // TODO: Merge the collect and pool functions, SS + + dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK)); + + AT_DISPATCH_INTEGRAL_TYPES( + pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] { + collect_inside_pts_for_box3d + <<>>( + boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, + pts_mask.data_ptr(), + pts_idx_of_voxels.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); + + dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, + boxes_num); + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_maxpool3d", [&] { + roiaware_maxpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr(), argmax.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_avgpool3d", [&] { + roiaware_avgpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr()); + }); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void RoiawarePool3dBackwardCUDAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + // params pool_method: 0: max_pool, 1: avg_pool + + at::cuda::CUDAGuard device_guard(grad_out.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, + boxes_num); + dim3 threads(THREADS_PER_BLOCK); + + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] { + roiaware_maxpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr(), + grad_out.data_ptr(), grad_in.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] { + roiaware_avgpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel, + pts_idx_of_voxels.data_ptr(), grad_out.data_ptr(), + grad_in.data_ptr()); + }); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/points_in_boxes.cpp b/mmcv/ops/csrc/pytorch/points_in_boxes.cpp new file mode 100644 index 0000000000..9ebeec9ab8 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/points_in_boxes.cpp @@ -0,0 +1,92 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_part_forward_cuda(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesPartForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; + +void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_all_forward_cuda(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesAllForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; +#endif + +void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box params pts: (B, npoints, 3) + // [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), + // default -1 + + if (pts_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_tensor); + CHECK_CUDA_INPUT(pts_tensor); + CHECK_CUDA_INPUT(box_idx_of_points_tensor); + + int batch_size = boxes_tensor.size(0); + int boxes_num = boxes_tensor.size(1); + int pts_num = pts_tensor.size(1); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *box_idx_of_points = box_idx_of_points_tensor.data_ptr(); + + points_in_boxes_part_forward_cuda(batch_size, boxes_num, pts_num, + boxes_tensor, pts_tensor, + box_idx_of_points_tensor); +#else + AT_ERROR("points_in_boxes_part is not compiled with GPU support"); +#endif + } else { + AT_ERROR("points_in_boxes_part is not implemented on CPU"); + } +} + +void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center. params pts: (B, npoints, 3) [x, y, z] + // in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1 + + if (pts_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_tensor); + CHECK_CUDA_INPUT(pts_tensor); + CHECK_CUDA_INPUT(box_idx_of_points_tensor); + + int batch_size = boxes_tensor.size(0); + int boxes_num = boxes_tensor.size(1); + int pts_num = pts_tensor.size(1); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *box_idx_of_points = box_idx_of_points_tensor.data_ptr(); + + points_in_boxes_all_forward_cuda(batch_size, boxes_num, pts_num, + boxes_tensor, pts_tensor, + box_idx_of_points_tensor); +#else + AT_ERROR("points_in_boxes_all is not compiled with GPU support"); +#endif + } else { + AT_ERROR("points_in_boxes_all is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp b/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp new file mode 100644 index 0000000000..c16baa4cca --- /dev/null +++ b/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp @@ -0,0 +1,53 @@ +#include "pytorch_cpp_helper.hpp" + +inline void lidar_to_local_coords_cpu(float shift_x, float shift_y, float rz, + float &local_x, float &local_y) { + float cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +inline int check_pt_in_box3d_cpu(const float *pt, const float *box3d, + float &local_x, float &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + float x = pt[0], y = pt[1], z = pt[2]; + float cx = box3d[0], cy = box3d[1], cz = box3d[2]; + float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords_cpu(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor pts_indices_tensor) { + // params boxes: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (npoints, 3) [x, y, z] in LiDAR coordinate params pts_indices: (N, npoints) + + CHECK_CONTIGUOUS(boxes_tensor); + CHECK_CONTIGUOUS(pts_tensor); + CHECK_CONTIGUOUS(pts_indices_tensor); + + int boxes_num = boxes_tensor.size(0); + int pts_num = pts_tensor.size(0); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *pts_indices = pts_indices_tensor.data_ptr(); + + float local_x = 0, local_y = 0; + for (int i = 0; i < boxes_num; i++) { + for (int j = 0; j < pts_num; j++) { + int cur_in_flag = + check_pt_in_box3d_cpu(pts + j * 3, boxes + i * 7, local_x, local_y); + pts_indices[i * pts_num + j] = cur_in_flag; + } + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index b9d6d6daa7..c5e3d1b697 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -296,6 +296,22 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, const Tensor &argmax_idx, Tensor grad_input, const int pool_size); +void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor pts_indices_tensor); + +void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor); + +void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor); + +void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_backward(Tensor pts_idx_of_voxels, Tensor argmax, + Tensor grad_out, Tensor grad_in, int pool_method); + void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH, int dilationW, int dilation_patchH, @@ -599,6 +615,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "backward function of border_align", py::arg("grad_output"), py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size")); + m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward, + "points_in_boxes_cpu_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("pts_indices_tensor")); + m.def("points_in_boxes_part_forward", &points_in_boxes_part_forward, + "points_in_boxes_part_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor")); + m.def("points_in_boxes_all_forward", &points_in_boxes_all_forward, + "points_in_boxes_all_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor")); + m.def("roiaware_pool3d_forward", &roiaware_pool3d_forward, + "roiaware_pool3d_forward", py::arg("rois"), py::arg("pts"), + py::arg("pts_feature"), py::arg("argmax"), py::arg("pts_idx_of_voxels"), + py::arg("pooled_features"), py::arg("pool_method")); + m.def("roiaware_pool3d_backward", &roiaware_pool3d_backward, + "roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"), + py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"), + py::arg("pool_method")); m.def("correlation_forward", &correlation_forward, "Correlation forward"); m.def("correlation_backward", &correlation_backward, "Correlation backward"); } diff --git a/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp b/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp new file mode 100644 index 0000000000..c7e267f8f0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp @@ -0,0 +1,115 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void RoiawarePool3dForwardCUDAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_forward_cuda(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + RoiawarePool3dForwardCUDAKernelLauncher( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, + rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features, + pool_method); +}; + +void RoiawarePool3dBackwardCUDAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method); + +void roiaware_pool3d_backward_cuda(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method) { + RoiawarePool3dBackwardCUDAKernelLauncher( + boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel, + pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method); +}; +#endif + +void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, ry] in LiDAR + // coordinate + // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate + // params pts_feature: (npoints, C) + // params argmax: (N, out_x, out_y, out_z, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params pooled_features: (N, out_x, out_y, out_z, C) + // params pool_method: 0: max_pool 1: avg_pool + if (pts.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(rois); + CHECK_CUDA_INPUT(pts); + CHECK_CUDA_INPUT(pts_feature); + CHECK_CUDA_INPUT(argmax); + CHECK_CUDA_INPUT(pts_idx_of_voxels); + CHECK_CUDA_INPUT(pooled_features); + + int boxes_num = rois.size(0); + int pts_num = pts.size(0); + int channels = pts_feature.size(1); + int max_pts_each_voxel = + pts_idx_of_voxels.size(4); // index 0 is the counter + int out_x = pts_idx_of_voxels.size(1); + int out_y = pts_idx_of_voxels.size(2); + int out_z = pts_idx_of_voxels.size(3); + assert((out_x < 256) && (out_y < 256) && + (out_z < 256)); // we encode index with 8bit + + roiaware_pool3d_forward_cuda(boxes_num, pts_num, channels, + max_pts_each_voxel, out_x, out_y, out_z, rois, + pts, pts_feature, argmax, pts_idx_of_voxels, + pooled_features, pool_method); +#else + AT_ERROR("roiaware_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roiaware_pool3d is not implemented on CPU"); + } +} + +void roiaware_pool3d_backward(Tensor pts_idx_of_voxels, Tensor argmax, + Tensor grad_out, Tensor grad_in, + int pool_method) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + // params pool_method: 0: max_pool 1: avg_pool + + if (grad_in.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(pts_idx_of_voxels); + CHECK_CUDA_INPUT(argmax); + CHECK_CUDA_INPUT(grad_out); + CHECK_CUDA_INPUT(grad_in); + + int boxes_num = pts_idx_of_voxels.size(0); + int out_x = pts_idx_of_voxels.size(1); + int out_y = pts_idx_of_voxels.size(2); + int out_z = pts_idx_of_voxels.size(3); + int max_pts_each_voxel = + pts_idx_of_voxels.size(4); // index 0 is the counter + int channels = grad_out.size(4); + + roiaware_pool3d_backward_cuda(boxes_num, out_x, out_y, out_z, channels, + max_pts_each_voxel, pts_idx_of_voxels, argmax, + grad_out, grad_in, pool_method); +#else + AT_ERROR("roiaware_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roiaware_pool3d is not implemented on CPU"); + } +} diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py new file mode 100644 index 0000000000..4003173a53 --- /dev/null +++ b/mmcv/ops/points_in_boxes.py @@ -0,0 +1,133 @@ +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward', + 'points_in_boxes_all_forward' +]) + + +def points_in_boxes_part(points, boxes): + """Find the box in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in + LiDAR/DEPTH coordinate, (x, y, z) is the bottom center + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M), default background = -1 + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + + box_idxs_of_pts = points.new_zeros((batch_size, num_points), + dtype=torch.int).fill_(-1) + + # If manually put the tensor 'points' or 'boxes' on a device + # which is not the current device, some temporary variables + # will be created on the current device in the cuda op, + # and the output will be incorrect. + # Therefore, we force the current device to be the same + # as the device of the tensors if it was not. + # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305 + # for the incorrect output before the fix. + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_part_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts + + +def points_in_boxes_cpu(points, boxes): + """Find all boxes in which each point is (CPU). The CPU version of + :meth:`points_in_boxes_all`. + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in + LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + point_indices = points.new_zeros((batch_size, num_boxes, num_points), + dtype=torch.int) + for b in range(batch_size): + ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(), + points[b].float().contiguous(), + point_indices[b]) + point_indices = point_indices.transpose(1, 2) + + return point_indices + + +def points_in_boxes_all(points, boxes): + """Find all boxes in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert boxes.shape[0] == points.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {boxes.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes), + dtype=torch.int).fill_(0) + + # Same reason as line 25-32 + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_all_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts diff --git a/mmcv/ops/roiaware_pool3d.py b/mmcv/ops/roiaware_pool3d.py new file mode 100644 index 0000000000..e593c7052f --- /dev/null +++ b/mmcv/ops/roiaware_pool3d.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn as nn +from torch.autograd import Function + +import mmcv +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward']) + + +class RoIAwarePool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `PartA2 `_ for more + details. + + Args: + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int, optional): The maximum number of points per + voxel. Default: 128. + mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'. + Default: 'max'. + """ + + def __init__(self, out_size, max_pts_per_voxel=128, mode='max'): + super().__init__() + + self.out_size = out_size + self.max_pts_per_voxel = max_pts_per_voxel + assert mode in ['max', 'avg'] + pool_mapping = {'max': 0, 'avg': 1} + self.mode = pool_mapping[mode] + + def forward(self, rois, pts, pts_feature): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] + """ + + return RoIAwarePool3dFunction.apply(rois, pts, pts_feature, + self.out_size, + self.max_pts_per_voxel, self.mode) + + +class RoIAwarePool3dFunction(Function): + + @staticmethod + def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, + mode): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int): The maximum number of points per voxel. + Default: 128. + mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average + pool). + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output + pooled features. + """ + + if isinstance(out_size, int): + out_x = out_y = out_z = out_size + else: + assert len(out_size) == 3 + assert mmcv.is_tuple_of(out_size, int) + out_x, out_y, out_z = out_size + + num_rois = rois.shape[0] + num_channels = pts_feature.shape[-1] + num_pts = pts.shape[0] + + pooled_features = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels)) + argmax = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int) + pts_idx_of_voxels = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, max_pts_per_voxel), + dtype=torch.int) + + ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax, + pts_idx_of_voxels, pooled_features, + mode) + + ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode, + num_pts, num_channels) + return pooled_features + + @staticmethod + def backward(ctx, grad_out): + ret = ctx.roiaware_pool3d_for_backward + pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret + + grad_in = grad_out.new_zeros((num_pts, num_channels)) + ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax, + grad_out.contiguous(), grad_in, + mode) + + return None, None, grad_in, None, None, None diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py new file mode 100644 index 0000000000..1d63e398da --- /dev/null +++ b/tests/test_ops/test_roiaware_pool3d.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +import torch + +from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu, + points_in_boxes_part) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_RoIAwarePool3d(): + roiaware_pool3d_max = RoIAwarePool3d( + out_size=4, max_pts_per_voxel=128, mode='max') + roiaware_pool3d_avg = RoIAwarePool3d( + out_size=4, max_pts_per_voxel=128, mode='avg') + rois = torch.tensor( + [[1.0, 2.0, 3.0, 5.0, 4.0, 6.0, -0.3 - np.pi / 2], + [-10.0, 23.0, 16.0, 20.0, 10.0, 20.0, -0.5 - np.pi / 2]], + dtype=torch.float32).cuda( + ) # boxes (m, 7) with bottom center in lidar coordinate + pts = torch.tensor( + [[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], + [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]], + dtype=torch.float32).cuda() # points (n, 3) in lidar coordinate + pts_feature = pts.clone() + + pooled_features_max = roiaware_pool3d_max( + rois=rois, pts=pts, pts_feature=pts_feature) + assert pooled_features_max.shape == torch.Size([2, 4, 4, 4, 3]) + assert torch.allclose(pooled_features_max.sum(), + torch.tensor(51.100).cuda(), 1e-3) + + pooled_features_avg = roiaware_pool3d_avg( + rois=rois, pts=pts, pts_feature=pts_feature) + assert pooled_features_avg.shape == torch.Size([2, 4, 4, 4, 3]) + assert torch.allclose(pooled_features_avg.sum(), + torch.tensor(49.750).cuda(), 1e-3) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_points_in_boxes_part(): + boxes = torch.tensor( + [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]], + [[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], + dtype=torch.float32).cuda( + ) # boxes (b, t, 7) with bottom center in lidar coordinate + pts = torch.tensor( + [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2]], + [[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5], + [0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]], + dtype=torch.float32).cuda() # points (b, m, 3) in lidar coordinate + + point_indices = points_in_boxes_part(points=pts, boxes=boxes) + expected_point_indices = torch.tensor( + [[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]], + dtype=torch.int32).cuda() + assert point_indices.shape == torch.Size([2, 8]) + assert (point_indices == expected_point_indices).all() + + boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]], + dtype=torch.float32).cuda() # 30 degrees + pts = torch.tensor( + [[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0], + [-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]], + dtype=torch.float32).cuda() + point_indices = points_in_boxes_part(points=pts, boxes=boxes) + expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]], + dtype=torch.int32).cuda() + assert (point_indices == expected_point_indices).all() + + +def test_points_in_boxes_cpu(): + boxes = torch.tensor( + [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], + [-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], + dtype=torch.float32 + ) # boxes (m, 7) with bottom center in lidar coordinate + pts = torch.tensor( + [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [ + -16, -18, 9 + ], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]]], + dtype=torch.float32) # points (n, 3) in lidar coordinate + + point_indices = points_in_boxes_cpu(points=pts, boxes=boxes) + expected_point_indices = torch.tensor( + [[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [0, 1], [0, 0], [0, 0], + [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]], + dtype=torch.int32) + assert point_indices.shape == torch.Size([1, 15, 2]) + assert (point_indices == expected_point_indices).all() + + boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]], + dtype=torch.float32) # 30 degrees + pts = torch.tensor( + [[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0], + [-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]], + dtype=torch.float32) + point_indices = points_in_boxes_cpu(points=pts, boxes=boxes) + expected_point_indices = torch.tensor( + [[[0], [0], [1], [0], [1], [0], [0], [0]]], dtype=torch.int32) + assert (point_indices == expected_point_indices).all() + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_points_in_boxes_all(): + + boxes = torch.tensor( + [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3], + [-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], + dtype=torch.float32).cuda( + ) # boxes (m, 7) with bottom center in lidar coordinate + pts = torch.tensor( + [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [ + -16, -18, 9 + ], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]]], + dtype=torch.float32).cuda() # points (n, 3) in lidar coordinate + + point_indices = points_in_boxes_all(points=pts, boxes=boxes) + expected_point_indices = torch.tensor( + [[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [0, 1], [0, 0], [0, 0], + [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]], + dtype=torch.int32).cuda() + assert point_indices.shape == torch.Size([1, 15, 2]) + assert (point_indices == expected_point_indices).all() From 36e7aa1fa918a368ea57e14385bfc829dd93ce87 Mon Sep 17 00:00:00 2001 From: dingchang Date: Fri, 22 Oct 2021 10:46:55 +0800 Subject: [PATCH 02/30] [Feature] Add iou3d op from mmdet3d (#1356) * add ops (iou3d) in mmdet3d * add unit test * refactor code * refactor code * refactor code * refactor code * refactor code Co-authored-by: zhouzaida --- mmcv/ops/__init__.py | 102 ++--- .../csrc/common/cuda/iou3d_cuda_kernel.cuh | 369 ++++++++++++++++++ mmcv/ops/csrc/common/pytorch_cpp_helper.hpp | 2 + mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu | 86 ++++ mmcv/ops/csrc/pytorch/iou3d.cpp | 244 ++++++++++++ mmcv/ops/csrc/pytorch/pybind.cpp | 22 ++ mmcv/ops/iou3d.py | 83 ++++ tests/test_ops/test_iou3d.py | 58 +++ 8 files changed, 888 insertions(+), 78 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/iou3d.cpp create mode 100644 mmcv/ops/iou3d.py create mode 100644 tests/test_ops/test_iou3d.py diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 338e32b652..b5a06c7614 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -24,6 +24,7 @@ from .gather_points import gather_points from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) +from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d from .modulated_deform_conv import (ModulatedDeformConv2d, @@ -53,82 +54,27 @@ from .voxelize import Voxelization, voxelization __all__ = [ - 'bbox_overlaps', - 'CARAFE', - 'CARAFENaive', - 'CARAFEPack', - 'carafe', - 'carafe_naive', - 'CornerPool', - 'DeformConv2d', - 'DeformConv2dPack', - 'deform_conv2d', - 'DeformRoIPool', - 'DeformRoIPoolPack', - 'ModulatedDeformRoIPoolPack', - 'deform_roi_pool', - 'SigmoidFocalLoss', - 'SoftmaxFocalLoss', - 'sigmoid_focal_loss', - 'softmax_focal_loss', - 'get_compiler_version', - 'get_compiling_cuda_version', - 'get_onnxruntime_op_path', - 'MaskedConv2d', - 'masked_conv2d', - 'ModulatedDeformConv2d', - 'ModulatedDeformConv2dPack', - 'modulated_deform_conv2d', - 'batched_nms', - 'nms', - 'soft_nms', - 'nms_match', - 'RoIAlign', - 'roi_align', - 'RoIPool', - 'roi_pool', - 'SyncBatchNorm', - 'Conv2d', - 'ConvTranspose2d', - 'Linear', - 'MaxPool2d', - 'CrissCrossAttention', - 'PSAMask', - 'point_sample', - 'rel_roi_point_to_rel_img_point', - 'SimpleRoIAlign', - 'SAConv2d', - 'TINShift', - 'tin_shift', - 'assign_score_withk', - 'box_iou_rotated', - 'RoIPointPool3d', - 'nms_rotated', - 'knn', - 'ball_query', - 'upfirdn2d', - 'FusedBiasLeakyReLU', - 'fused_bias_leakyrelu', - 'RoIAlignRotated', - 'roi_align_rotated', - 'pixel_group', - 'contour_expand', - 'three_nn', - 'three_interpolate', - 'MultiScaleDeformableAttention', - 'Voxelization', - 'voxelization', - 'dynamic_scatter', - 'DynamicScatter', - 'BorderAlign', - 'border_align', - 'gather_points', - 'furthest_point_sample', - 'furthest_point_sample_with_dist', - 'PointsSampler', - 'Correlation', - 'RoIAwarePool3d', - 'points_in_boxes_part', - 'points_in_boxes_cpu', - 'points_in_boxes_all', + 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', + 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', + 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', + 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', + 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', + 'get_compiler_version', 'get_compiling_cuda_version', + 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', + 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', + 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', + 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', + 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', + 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', + 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', + 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev', + 'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated', + 'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn', + 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', + 'border_align', 'gather_points', 'furthest_point_sample', + 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', + 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', + 'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu', + 'points_in_boxes_all' ] diff --git a/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh new file mode 100644 index 0000000000..4e261cbd0c --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh @@ -0,0 +1,369 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef IOU3D_CUDA_KERNEL_CUH +#define IOU3D_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +const int THREADS_PER_BLOCK_IOU3D = 16; +const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; +__device__ const float EPS = 1e-8; + +struct Point { + float x, y; + __device__ Point() {} + __device__ Point(double _x, double _y) { x = _x, y = _y; } + + __device__ void set(float _x, float _y) { + x = _x; + y = _y; + } + + __device__ Point operator+(const Point &b) const { + return Point(x + b.x, y + b.y); + } + + __device__ Point operator-(const Point &b) const { + return Point(x - b.x, y - b.y); + } +}; + +__device__ inline float cross(const Point &a, const Point &b) { + return a.x * b.y - a.y * b.x; +} + +__device__ inline float cross(const Point &p1, const Point &p2, + const Point &p0) { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); +} + +__device__ int check_rect_cross(const Point &p1, const Point &p2, + const Point &q1, const Point &q2) { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; +} + +__device__ inline int check_in_box2d(const float *box, const Point &p) { + // params: box (5) [x1, y1, x2, y2, angle] + const float MARGIN = 1e-5; + + float center_x = (box[0] + box[2]) / 2; + float center_y = (box[1] + box[3]) / 2; + float angle_cos = cos(-box[4]), + angle_sin = + sin(-box[4]); // rotate the point in the opposite direction of box + float rot_x = + (p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x; + float rot_y = + (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y; + + return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN && + rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN); +} + +__device__ inline int intersection(const Point &p1, const Point &p0, + const Point &q1, const Point &q0, + Point &ans_point) { + // fast exclusion + if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; + + // check cross standing + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + + if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0; + + // calculate intersection of two lines + float s5 = cross(q1, p1, p0); + if (fabs(s5 - s1) > EPS) { + ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); + ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); + + } else { + float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans_point.x = (b0 * c1 - b1 * c0) / D; + ans_point.y = (a1 * c0 - a0 * c1) / D; + } + + return 1; +} + +__device__ inline void rotate_around_center(const Point ¢er, + const float angle_cos, + const float angle_sin, Point &p) { + float new_x = + (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x; + float new_y = + (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); +} + +__device__ inline int point_cmp(const Point &a, const Point &b, + const Point ¢er) { + return atan2(a.y - center.y, a.x - center.x) > + atan2(b.y - center.y, b.x - center.x); +} + +__device__ inline float box_overlap(const float *box_a, const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + + float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], + a_angle = box_a[4]; + float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], + b_angle = box_b[4]; + + Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); + Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); + float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); + + for (int k = 0; k < 4; k++) { + rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); + } + + box_a_corners[4] = box_a_corners[0]; + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0, flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(box_a, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(box_b, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + + poly_center.x /= cnt; + poly_center.y /= cnt; + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + + // get the overlap areas + float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return fabs(area) / 2.0; +} + +__device__ inline float iou_bev(const float *box_a, const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]); + float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); + float s_overlap = box_overlap(box_a, box_b); + return s_overlap / fmaxf(sa + sb - s_overlap, EPS); +} + +__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel( + const int num_a, const float *boxes_a, const int num_b, + const float *boxes_b, float *ans_overlap) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + const float *cur_box_a = boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float s_overlap = box_overlap(cur_box_a, cur_box_b); + ans_overlap[a_idx * num_b + b_idx] = s_overlap; +} + +__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a, + const float *boxes_a, + const int num_b, + const float *boxes_b, + float *ans_iou) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + + const float *cur_box_a = boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float cur_iou_bev = iou_bev(cur_box_a, cur_box_b); + ans_iou[a_idx * num_b + b_idx] = cur_iou_bev; +} + +__global__ void nms_forward_cuda_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 5) [x1, y1, x2, y2, ry] + // params: mask (N, N/THREADS_PER_BLOCK_NMS) + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; + const float *cur_box = boxes + cur_box_idx * 5; + + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +__device__ inline float iou_normal(float const *const a, float const *const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0]) * (a[3] - a[1]); + float Sb = (b[2] - b[0]) * (b[3] - b[1]); + return interS / fmaxf(Sa + Sb - interS, EPS); +} + +__global__ void nms_normal_forward_cuda_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 5) [x1, y1, x2, y2, ry] + // params: mask (N, N/THREADS_PER_BLOCK_NMS) + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; + const float *cur_box = boxes + cur_box_idx * 5; + + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +#endif // IOU3D_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp index b812e62713..c7f9f35b7b 100644 --- a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp @@ -6,6 +6,8 @@ using namespace at; +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CPU(x) \ diff --git a/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu new file mode 100644 index 0000000000..0643c16044 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu @@ -0,0 +1,86 @@ +// Modified from +// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu + +/* +3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others) +Written by Shaoshuai Shi +All Rights Reserved 2019-2020. +*/ + +#include + +#include "iou3d_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_overlap) { + at::cuda::CUDAGuard device_guard(boxes_a.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D), + DIVUP(num_a, THREADS_PER_BLOCK_IOU3D)); + dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D); + + iou3d_boxes_overlap_bev_forward_cuda_kernel<<>>( + num_a, boxes_a.data_ptr(), num_b, boxes_b.data_ptr(), + ans_overlap.data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_iou) { + at::cuda::CUDAGuard device_guard(boxes_a.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D), + DIVUP(num_a, THREADS_PER_BLOCK_IOU3D)); + dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D); + + iou3d_boxes_iou_bev_forward_cuda_kernel<<>>( + num_a, boxes_a.data_ptr(), num_b, boxes_b.data_ptr(), + ans_iou.data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, int boxes_num, + float nms_overlap_thresh) { + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS), + DIVUP(boxes_num, THREADS_PER_BLOCK_NMS)); + dim3 threads(THREADS_PER_BLOCK_NMS); + + nms_forward_cuda_kernel<<>>( + boxes_num, nms_overlap_thresh, boxes.data_ptr(), mask); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, + int boxes_num, + float nms_overlap_thresh) { + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS), + DIVUP(boxes_num, THREADS_PER_BLOCK_NMS)); + dim3 threads(THREADS_PER_BLOCK_NMS); + + nms_normal_forward_cuda_kernel<<>>( + boxes_num, nms_overlap_thresh, boxes.data_ptr(), mask); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/iou3d.cpp b/mmcv/ops/csrc/pytorch/iou3d.cpp new file mode 100644 index 0000000000..eecfdf224a --- /dev/null +++ b/mmcv/ops/csrc/pytorch/iou3d.cpp @@ -0,0 +1,244 @@ +// Modified from +// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp + +/* +3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others) +Written by Shaoshuai Shi +All Rights Reserved 2019-2020. +*/ + +#include "pytorch_cpp_helper.hpp" + +const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; + +#ifdef MMCV_WITH_CUDA +#include +#include + +#define CHECK_ERROR(state) \ + { gpuAssert((state), __FILE__, __LINE__); } +inline void gpuAssert(cudaError_t code, const char *file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) exit(code); + } +} + +void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_overlap); +void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_overlap) { + IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, + ans_overlap); +}; + +void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a, + const Tensor boxes_a, + const int num_b, + const Tensor boxes_b, + Tensor ans_iou); +void iou3d_boxes_iou_bev_forward_cuda(const int num_a, const Tensor boxes_a, + const int num_b, const Tensor boxes_b, + Tensor ans_iou) { + IoU3DBoxesIoUBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, + ans_iou); +}; + +void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, int boxes_num, + float nms_overlap_thresh); + +void iou3d_nms_forward_cuda(const Tensor boxes, unsigned long long *mask, + int boxes_num, float nms_overlap_thresh) { + IoU3DNMSForwardCUDAKernelLauncher(boxes, mask, boxes_num, nms_overlap_thresh); +}; + +void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes, + unsigned long long *mask, + int boxes_num, + float nms_overlap_thresh); + +void iou3d_nms_normal_forward_cuda(const Tensor boxes, unsigned long long *mask, + int boxes_num, float nms_overlap_thresh) { + IoU3DNMSNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num, + nms_overlap_thresh); +}; +#endif + +void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, + Tensor ans_overlap) { + // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] + // params boxes_b: (M, 5) + // params ans_overlap: (N, M) + + if (boxes_a.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_a); + CHECK_CUDA_INPUT(boxes_b); + CHECK_CUDA_INPUT(ans_overlap); + + int num_a = boxes_a.size(0); + int num_b = boxes_b.size(0); + + iou3d_boxes_overlap_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b, + ans_overlap); +#else + AT_ERROR("iou3d_boxes_overlap_bev is not compiled with GPU support"); +#endif + } else { + AT_ERROR("iou3d_boxes_overlap_bev is not implemented on CPU"); + } +} + +void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b, + Tensor ans_iou) { + // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] + // params boxes_b: (M, 5) + // params ans_overlap: (N, M) + + if (boxes_a.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_a); + CHECK_CUDA_INPUT(boxes_b); + CHECK_CUDA_INPUT(ans_iou); + + int num_a = boxes_a.size(0); + int num_b = boxes_b.size(0); + + iou3d_boxes_iou_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b, ans_iou); +#else + AT_ERROR("iou3d_boxes_iou_bev is not compiled with GPU support"); +#endif + } else { + AT_ERROR("iou3d_boxes_iou_bev is not implemented on CPU"); + } +} + +int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) { + // params boxes: (N, 5) [x1, y1, x2, y2, ry] + // params keep: (N) + + if (boxes.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes); + CHECK_CONTIGUOUS(keep); + + int boxes_num = boxes.size(0); + int64_t *keep_data = keep.data_ptr(); + + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + + unsigned long long *mask_data = NULL; + CHECK_ERROR( + cudaMalloc((void **)&mask_data, + boxes_num * col_blocks * sizeof(unsigned long long))); + iou3d_nms_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh); + + // unsigned long long mask_cpu[boxes_num * col_blocks]; + // unsigned long long *mask_cpu = new unsigned long long [boxes_num * + // col_blocks]; + std::vector mask_cpu(boxes_num * col_blocks); + + // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); + CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, + boxes_num * col_blocks * sizeof(unsigned long long), + cudaMemcpyDeviceToHost)); + + cudaFree(mask_data); + + unsigned long long *remv_cpu = new unsigned long long[col_blocks](); + + int num_to_keep = 0; + + for (int i = 0; i < boxes_num; i++) { + int nblock = i / THREADS_PER_BLOCK_NMS; + int inblock = i % THREADS_PER_BLOCK_NMS; + + if (!(remv_cpu[nblock] & (1ULL << inblock))) { + keep_data[num_to_keep++] = i; + unsigned long long *p = &mask_cpu[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv_cpu[j] |= p[j]; + } + } + } + delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); + + return num_to_keep; + +#else + AT_ERROR("iou3d_nms is not compiled with GPU support"); +#endif + } else { + AT_ERROR("iou3d_nms is not implemented on CPU"); + } +} + +int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, + float nms_overlap_thresh) { + // params boxes: (N, 5) [x1, y1, x2, y2, ry] + // params keep: (N) + + if (boxes.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes); + CHECK_CONTIGUOUS(keep); + + int boxes_num = boxes.size(0); + int64_t *keep_data = keep.data_ptr(); + + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + + unsigned long long *mask_data = NULL; + CHECK_ERROR( + cudaMalloc((void **)&mask_data, + boxes_num * col_blocks * sizeof(unsigned long long))); + iou3d_nms_normal_forward_cuda(boxes, mask_data, boxes_num, + nms_overlap_thresh); + + // unsigned long long mask_cpu[boxes_num * col_blocks]; + // unsigned long long *mask_cpu = new unsigned long long [boxes_num * + // col_blocks]; + std::vector mask_cpu(boxes_num * col_blocks); + + CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, + boxes_num * col_blocks * sizeof(unsigned long long), + cudaMemcpyDeviceToHost)); + + cudaFree(mask_data); + + unsigned long long *remv_cpu = new unsigned long long[col_blocks](); + + int num_to_keep = 0; + + for (int i = 0; i < boxes_num; i++) { + int nblock = i / THREADS_PER_BLOCK_NMS; + int inblock = i % THREADS_PER_BLOCK_NMS; + + if (!(remv_cpu[nblock] & (1ULL << inblock))) { + keep_data[num_to_keep++] = i; + unsigned long long *p = &mask_cpu[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv_cpu[j] |= p[j]; + } + } + } + delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); + + return num_to_keep; + +#else + AT_ERROR("iou3d_nms_normal is not compiled with GPU support"); +#endif + } else { + AT_ERROR("iou3d_nms_normal is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index c5e3d1b697..7b39a5e443 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -105,6 +105,17 @@ void three_nn_forward(int b, int n, int m, Tensor unknown_tensor, void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); +void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, + Tensor ans_overlap); + +void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b, + Tensor ans_iou); + +int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh); + +int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, + float nms_overlap_thresh); + void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor); @@ -442,6 +453,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); + m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward, + "iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"), + py::arg("boxes_b"), py::arg("ans_overlap")); + m.def("iou3d_boxes_iou_bev_forward", &iou3d_boxes_iou_bev_forward, + "iou3d_boxes_iou_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"), + py::arg("ans_iou")); + m.def("iou3d_nms_forward", &iou3d_nms_forward, "iou3d_nms_forward", + py::arg("boxes"), py::arg("keep"), py::arg("nms_overlap_thresh")); + m.def("iou3d_nms_normal_forward", &iou3d_nms_normal_forward, + "iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"), + py::arg("nms_overlap_thresh")); m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward, "furthest_point_sampling_forward", py::arg("b"), py::arg("n"), py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"), diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py new file mode 100644 index 0000000000..f22a9c82c0 --- /dev/null +++ b/mmcv/ops/iou3d.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward', + 'iou3d_nms_normal_forward' +]) + + +def boxes_iou_bev(boxes_a, boxes_b): + """Calculate boxes IoU in the Bird's Eye View. + + Args: + boxes_a (torch.Tensor): Input boxes a with shape (M, 5). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + + Returns: + ans_iou (torch.Tensor): IoU result with shape (M, N). + """ + ans_iou = boxes_a.new_zeros( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) + + ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(), + boxes_b.contiguous(), ans_iou) + + return ans_iou + + +def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): + """NMS function GPU implementation (for BEV boxes). The overlap of two + boxes for IoU calculation is defined as the exact overlapping area of the + two boxes. In this function, one can also set ``pre_max_size`` and + ``post_max_size``. + + Args: + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (float): Overlap threshold of NMS. + pre_max_size (int, optional): Max size of boxes before NMS. + Default: None. + post_max_size (int, optional): Max size of boxes after NMS. + Default: None. + + Returns: + torch.Tensor: Indexes after NMS. + """ + order = scores.sort(0, descending=True)[1] + + if pre_max_size is not None: + order = order[:pre_max_size] + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh) + keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + if post_max_size is not None: + keep = keep[:post_max_size] + return keep + + +def nms_normal_bev(boxes, scores, thresh): + """Normal NMS function GPU implementation (for BEV boxes). The overlap of + two boxes for IoU calculation is defined as the exact overlapping area of + the two boxes WITH their yaw angle set to 0. + + Args: + boxes (torch.Tensor): Input boxes with shape (N, 5). + scores (torch.Tensor): Scores of predicted boxes with shape (N). + thresh (float): Overlap threshold of NMS. + + Returns: + torch.Tensor: Remaining indices with scores in descending order. + """ + order = scores.sort(0, descending=True)[1] + + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh) + return order[keep[:num_out].cuda(boxes.device)].contiguous() diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py new file mode 100644 index 0000000000..9747e131f0 --- /dev/null +++ b/tests/test_ops/test_iou3d.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest +import torch + +from mmcv.ops import boxes_iou_bev, nms_bev, nms_normal_bev + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_boxes_iou_bev(): + np_boxes1 = np.asarray( + [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], + [7.0, 7.0, 8.0, 8.0, 0.4]], + dtype=np.float32) + np_boxes2 = np.asarray( + [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5], + [5.0, 5.0, 6.0, 7.0, 0.4]], + dtype=np.float32) + np_expect_ious = np.asarray( + [[0.2621, 0.2948, 0.0000], [0.0549, 0.1587, 0.0000], + [0.0000, 0.0000, 0.0000]], + dtype=np.float32) + + boxes1 = torch.from_numpy(np_boxes1).cuda() + boxes2 = torch.from_numpy(np_boxes2).cuda() + + ious = boxes_iou_bev(boxes1, boxes2) + assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_nms_gpu(): + np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], + [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], + dtype=np.float32) + np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) + np_inds = np.array([1, 0, 3]) + boxes = torch.from_numpy(np_boxes) + scores = torch.from_numpy(np_scores) + inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3) + + assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_nms_normal_gpu(): + np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], + [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], + dtype=np.float32) + np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) + np_inds = np.array([1, 2, 0, 3]) + boxes = torch.from_numpy(np_boxes) + scores = torch.from_numpy(np_scores) + inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3) + + assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu From 48c7b57688228d5fbf316bf3c6a5add1b6453199 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 22 Oct 2021 21:18:19 +0800 Subject: [PATCH 03/30] [Fix] Update test data for test_iou3d (#1427) * Update test data for test_iou3d * delete blank lines Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmcv/ops/csrc/pytorch/iou3d.cpp | 58 +++++++++++++-------------------- mmcv/ops/iou3d.py | 2 ++ tests/test_ops/test_iou3d.py | 24 +++++++------- 3 files changed, 37 insertions(+), 47 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/iou3d.cpp b/mmcv/ops/csrc/pytorch/iou3d.cpp index eecfdf224a..46051c4d66 100644 --- a/mmcv/ops/csrc/pytorch/iou3d.cpp +++ b/mmcv/ops/csrc/pytorch/iou3d.cpp @@ -134,25 +134,18 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) { const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); - unsigned long long *mask_data = NULL; - CHECK_ERROR( - cudaMalloc((void **)&mask_data, - boxes_num * col_blocks * sizeof(unsigned long long))); + Tensor mask = + at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); + unsigned long long *mask_data = + (unsigned long long *)mask.data_ptr(); iou3d_nms_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh); - // unsigned long long mask_cpu[boxes_num * col_blocks]; - // unsigned long long *mask_cpu = new unsigned long long [boxes_num * - // col_blocks]; - std::vector mask_cpu(boxes_num * col_blocks); + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long *mask_host = + (unsigned long long *)mask_cpu.data_ptr(); - // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); - CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, - boxes_num * col_blocks * sizeof(unsigned long long), - cudaMemcpyDeviceToHost)); - - cudaFree(mask_data); - - unsigned long long *remv_cpu = new unsigned long long[col_blocks](); + std::vector remv_cpu(col_blocks); + memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks); int num_to_keep = 0; @@ -162,13 +155,13 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) { if (!(remv_cpu[nblock] & (1ULL << inblock))) { keep_data[num_to_keep++] = i; - unsigned long long *p = &mask_cpu[0] + i * col_blocks; + unsigned long long *p = &mask_host[0] + i * col_blocks; for (int j = nblock; j < col_blocks; j++) { remv_cpu[j] |= p[j]; } } } - delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; @@ -196,26 +189,19 @@ int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); - unsigned long long *mask_data = NULL; - CHECK_ERROR( - cudaMalloc((void **)&mask_data, - boxes_num * col_blocks * sizeof(unsigned long long))); + Tensor mask = + at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); + unsigned long long *mask_data = + (unsigned long long *)mask.data_ptr(); iou3d_nms_normal_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh); - // unsigned long long mask_cpu[boxes_num * col_blocks]; - // unsigned long long *mask_cpu = new unsigned long long [boxes_num * - // col_blocks]; - std::vector mask_cpu(boxes_num * col_blocks); - - CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, - boxes_num * col_blocks * sizeof(unsigned long long), - cudaMemcpyDeviceToHost)); - - cudaFree(mask_data); - - unsigned long long *remv_cpu = new unsigned long long[col_blocks](); + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long *mask_host = + (unsigned long long *)mask_cpu.data_ptr(); + std::vector remv_cpu(col_blocks); + memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks); int num_to_keep = 0; for (int i = 0; i < boxes_num; i++) { @@ -224,13 +210,13 @@ int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, if (!(remv_cpu[nblock] & (1ULL << inblock))) { keep_data[num_to_keep++] = i; - unsigned long long *p = &mask_cpu[0] + i * col_blocks; + unsigned long long *p = &mask_host[0] + i * col_blocks; for (int j = nblock; j < col_blocks; j++) { remv_cpu[j] |= p[j]; } } } - delete[] remv_cpu; + if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py index f22a9c82c0..6fc7197919 100644 --- a/mmcv/ops/iou3d.py +++ b/mmcv/ops/iou3d.py @@ -47,6 +47,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): Returns: torch.Tensor: Indexes after NMS. """ + assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' order = scores.sort(0, descending=True)[1] if pre_max_size is not None: @@ -74,6 +75,7 @@ def nms_normal_bev(boxes, scores, thresh): Returns: torch.Tensor: Remaining indices with scores in descending order. """ + assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' order = scores.sort(0, descending=True)[1] boxes = boxes[order].contiguous() diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py index 9747e131f0..21ed84a9e5 100644 --- a/tests/test_ops/test_iou3d.py +++ b/tests/test_ops/test_iou3d.py @@ -30,29 +30,31 @@ def test_boxes_iou_bev(): @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_nms_gpu(): - np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], - [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], - dtype=np.float32) +def test_nms_bev(): + np_boxes = np.array( + [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0], + [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]], + dtype=np.float32) np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) np_inds = np.array([1, 0, 3]) boxes = torch.from_numpy(np_boxes) scores = torch.from_numpy(np_scores) inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3) - assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu + assert np.allclose(inds.cpu().numpy(), np_inds) @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_nms_normal_gpu(): - np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], - [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], - dtype=np.float32) +def test_nms_normal_bev(): + np_boxes = np.array( + [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0], + [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]], + dtype=np.float32) np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) - np_inds = np.array([1, 2, 0, 3]) + np_inds = np.array([1, 0, 3]) boxes = torch.from_numpy(np_boxes) scores = torch.from_numpy(np_scores) inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3) - assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu + assert np.allclose(inds.cpu().numpy(), np_inds) From 243e6c726ae4081fa53f0a32160d6ae4ca167001 Mon Sep 17 00:00:00 2001 From: dingchang Date: Sat, 23 Oct 2021 14:01:31 +0800 Subject: [PATCH 04/30] [Feature] Add group points ops from mmdet3d (#1415) * add op (group points) and its related ops (ball query and knn) in mmdet3d * refactor code * fix typo * refactor code * fix typo * refactor code * make input contiguous Co-authored-by: zhouzaida --- docs/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 13 +- .../common/cuda/group_points_cuda_kernel.cuh | 63 +++++ .../csrc/pytorch/cuda/group_points_cuda.cu | 61 +++++ mmcv/ops/csrc/pytorch/group_points.cpp | 58 +++++ mmcv/ops/csrc/pytorch/pybind.cpp | 20 ++ mmcv/ops/group_points.py | 224 ++++++++++++++++++ tests/test_ops/test_group_points.py | 76 ++++++ 8 files changed, 510 insertions(+), 6 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/group_points.cpp create mode 100644 mmcv/ops/group_points.py create mode 100644 tests/test_ops/test_group_points.py diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index 900705afa0..2729e441c1 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -16,6 +16,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - FurthestPointSample - FurthestPointSampleWithDist - GeneralizedAttention +- GroupPoints - KNN - MaskedConv - NMS diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index b5a06c7614..999e090a45 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -22,6 +22,7 @@ furthest_point_sample_with_dist) from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu from .gather_points import gather_points +from .group_points import GroupAll, QueryAndGroup, grouping_operation from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev @@ -68,13 +69,13 @@ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', - 'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev', - 'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated', - 'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', + 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', - 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', - 'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu', - 'points_in_boxes_all' + 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', + 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', + 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all' ] diff --git a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh new file mode 100644 index 0000000000..9cfc2dc865 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh @@ -0,0 +1,63 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#ifndef GROUP_POINTS_CUDA_KERNEL_CUH +#define GROUP_POINTS_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__global__ void group_points_forward_cuda_kernel(int b, int c, int n, + int npoints, int nsample, + const T *points, + const int *__restrict__ idx, + T *out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + +template +__global__ void group_points_backward_cuda_kernel(int b, int c, int n, + int npoints, int nsample, + const T *grad_out, + const int *__restrict__ idx, + T *grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]); +} + +#endif // GROUP_POINTS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu new file mode 100644 index 0000000000..e7c57b018a --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu @@ -0,0 +1,61 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#include +#include + +#include "group_points_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor points, + const Tensor idx, Tensor out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "group_points_forward_cuda_kernel", [&] { + group_points_forward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, points.data_ptr(), + idx.data_ptr(), out.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + + at::cuda::CUDAGuard device_guard(grad_out.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_out.scalar_type(), "group_points_backward_cuda_kernel", [&] { + group_points_backward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, grad_out.data_ptr(), + idx.data_ptr(), grad_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/group_points.cpp b/mmcv/ops/csrc/pytorch/group_points.cpp new file mode 100644 index 0000000000..1ebc947a19 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/group_points.cpp @@ -0,0 +1,58 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor points, + const Tensor idx, Tensor out); +void group_points_forward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor points, const Tensor idx, + Tensor out) { + GroupPointsForwardCUDAKernelLauncher(b, c, n, npoints, nsample, points, idx, + out); +}; + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points); +void group_points_backward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor grad_out, const Tensor idx, + Tensor grad_points) { + GroupPointsBackwardCUDAKernelLauncher(b, c, n, npoints, nsample, grad_out, + idx, grad_points); +}; +#endif + +void group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_forward_cuda(b, c, n, npoints, nsample, points_tensor, + idx_tensor, out_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} + +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_backward_cuda(b, c, n, npoints, nsample, grad_out_tensor, + idx_tensor, grad_points_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 7b39a5e443..8f52e26e82 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -65,6 +65,14 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, int pooled_width, float spatial_scale, int sampling_ratio, float gamma); +void group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor); + +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor); + void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); @@ -453,6 +461,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); + m.def("group_points_forward", &group_points_forward, "group_points_forward", + py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"), + py::arg("nsample"), py::arg("points_tensor"), py::arg("idx_tensor"), + py::arg("out_tensor")); + m.def("group_points_backward", &group_points_backward, + "group_points_backward", py::arg("b"), py::arg("c"), py::arg("n"), + py::arg("npoints"), py::arg("nsample"), py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("grad_points_tensor")); + m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), + py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), + py::arg("new_xyz_tensor"), py::arg("idx_tensor"), + py::arg("dist2_tensor")); m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward, "iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"), py::arg("ans_overlap")); diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py new file mode 100644 index 0000000000..5afd227944 --- /dev/null +++ b/mmcv/ops/group_points.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader +from .ball_query import ball_query +from .knn import knn + +ext_module = ext_loader.load_ext( + '_ext', ['group_points_forward', 'group_points_backward']) + + +class QueryAndGroup(nn.Module): + """Groups points with a ball query of radius. + + Args: + max_radius (float): The maximum radius of the balls. + If None is given, we will use kNN sampling instead of ball query. + sample_num (int): Maximum number of features to gather in the ball. + min_radius (float, optional): The minimum radius of the balls. + Default: 0. + use_xyz (bool, optional): Whether to use xyz. + Default: True. + return_grouped_xyz (bool, optional): Whether to return grouped xyz. + Default: False. + normalize_xyz (bool, optional): Whether to normalize xyz. + Default: False. + uniform_sample (bool, optional): Whether to sample uniformly. + Default: False + return_unique_cnt (bool, optional): Whether to return the count of + unique samples. Default: False. + return_grouped_idx (bool, optional): Whether to return grouped idx. + Default: False. + """ + + def __init__(self, + max_radius, + sample_num, + min_radius=0, + use_xyz=True, + return_grouped_xyz=False, + normalize_xyz=False, + uniform_sample=False, + return_unique_cnt=False, + return_grouped_idx=False): + super().__init__() + self.max_radius = max_radius + self.min_radius = min_radius + self.sample_num = sample_num + self.use_xyz = use_xyz + self.return_grouped_xyz = return_grouped_xyz + self.normalize_xyz = normalize_xyz + self.uniform_sample = uniform_sample + self.return_unique_cnt = return_unique_cnt + self.return_grouped_idx = return_grouped_idx + if self.return_unique_cnt: + assert self.uniform_sample, \ + 'uniform_sample should be True when ' \ + 'returning the count of unique samples' + if self.max_radius is None: + assert not self.normalize_xyz, \ + 'can not normalize grouped xyz when max_radius is None' + + def forward(self, points_xyz, center_xyz, features=None): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods. + features (Tensor): (B, C, N) Descriptors of the features. + + Return: + Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. + """ + # if self.max_radius is None, we will perform kNN instead of ball query + # idx is of shape [B, npoint, sample_num] + if self.max_radius is None: + idx = knn(self.sample_num, points_xyz, center_xyz, False) + idx = idx.transpose(1, 2).contiguous() + else: + idx = ball_query(self.min_radius, self.max_radius, self.sample_num, + points_xyz, center_xyz) + + if self.uniform_sample: + unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) + for i_batch in range(idx.shape[0]): + for i_region in range(idx.shape[1]): + unique_ind = torch.unique(idx[i_batch, i_region, :]) + num_unique = unique_ind.shape[0] + unique_cnt[i_batch, i_region] = num_unique + sample_ind = torch.randint( + 0, + num_unique, (self.sample_num - num_unique, ), + dtype=torch.long) + all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) + idx[i_batch, i_region, :] = all_ind + + xyz_trans = points_xyz.transpose(1, 2).contiguous() + # (B, 3, npoint, sample_num) + grouped_xyz = grouping_operation(xyz_trans, idx) + grouped_xyz_diff = grouped_xyz - \ + center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets + if self.normalize_xyz: + grouped_xyz_diff /= self.max_radius + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + # (B, C + 3, npoint, sample_num) + new_features = torch.cat([grouped_xyz_diff, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + assert (self.use_xyz + ), 'Cannot have not features and not use xyz as a feature!' + new_features = grouped_xyz_diff + + ret = [new_features] + if self.return_grouped_xyz: + ret.append(grouped_xyz) + if self.return_unique_cnt: + ret.append(unique_cnt) + if self.return_grouped_idx: + ret.append(idx) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + + +class GroupAll(nn.Module): + """Group xyz with feature. + + Args: + use_xyz (bool): Whether to use xyz. + """ + + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, + xyz: torch.Tensor, + new_xyz: torch.Tensor, + features: torch.Tensor = None): + """ + Args: + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + new_xyz (Tensor): new xyz coordinates of the features. + features (Tensor): (B, C, N) features to group. + + Return: + Tensor: (B, C + 3, 1, N) Grouped feature. + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + # (B, 3 + C, 1, N) + new_features = torch.cat([grouped_xyz, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features + + +class GroupingOperation(Function): + """Group feature with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) tensor of features to group. + indices (Tensor): (B, npoint, nsample) the indices of + features to group with. + + Returns: + Tensor: (B, C, npoint, nsample) Grouped features. + """ + features = features.contiguous() + indices = indices.contiguous() + + B, nfeatures, nsample = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + ext_module.group_points_forward(B, C, N, nfeatures, nsample, features, + indices, output) + + ctx.for_backwards = (indices, N) + return output + + @staticmethod + def backward(ctx, + grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients + of the output from forward. + + Returns: + Tensor: (B, C, N) gradient of the features. + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + + grad_out_data = grad_out.data.contiguous() + ext_module.group_points_backward(B, C, N, npoint, nsample, + grad_out_data, idx, + grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py new file mode 100644 index 0000000000..1b495c2850 --- /dev/null +++ b/tests/test_ops/test_group_points.py @@ -0,0 +1,76 @@ +import pytest +import torch + +from mmcv.ops import grouping_operation + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_grouping_points(): + idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], + [0, 0, 0]]]).int().cuda() + festures = torch.tensor([[[ + 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, + 0.9268, 0.8414 + ], + [ + 5.4247, 1.5113, 2.3944, 1.4740, 5.0300, + 5.1030, 1.9360, 2.1939, 2.1581, 3.4666 + ], + [ + -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, + -0.5732, -1.0830, -1.7561, -1.6786, -1.6967 + ]], + [[ + -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, + 0.7798, -0.3693, -0.9457, -0.2942, -1.8527 + ], + [ + 1.1773, 1.5009, 2.6399, 5.9242, 1.0962, + 2.7346, 6.0865, 1.5555, 4.3303, 2.8229 + ], + [ + -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, + -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 + ]]]).cuda() + + output = grouping_operation(festures, idx) + expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], + [-1.3311, -1.3311, -1.3311], + [0.9268, 0.9268, 0.9268], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798]], + [[5.4247, 5.4247, 5.4247], + [1.4740, 1.4740, 1.4740], + [2.1581, 2.1581, 2.1581], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247]], + [[-1.6266, -1.6266, -1.6266], + [-1.6931, -1.6931, -1.6931], + [-1.6786, -1.6786, -1.6786], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266]]], + [[[-0.0380, -0.0380, -0.0380], + [-0.3693, -0.3693, -0.3693], + [-1.8527, -1.8527, -1.8527], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380]], + [[1.1773, 1.1773, 1.1773], + [6.0865, 6.0865, 6.0865], + [2.8229, 2.8229, 2.8229], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773]], + [[-0.6646, -0.6646, -0.6646], + [0.4990, 0.4990, 0.4990], + [0.0386, 0.0386, 0.0386], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646]]]]).cuda() + assert torch.allclose(output, expected_output) From 9b3cffd0277f76bcc3b8e44d65499650b4147d51 Mon Sep 17 00:00:00 2001 From: pc Date: Sat, 23 Oct 2021 14:18:56 +0800 Subject: [PATCH 05/30] add mmdet3d op (#1425) Co-authored-by: zhouzaida --- mmcv/ops/assign_score_withk.py | 44 +++-- mmcv/ops/ball_query.py | 15 +- mmcv/ops/correlation.py | 43 ++++- mmcv/ops/csrc/parrots/assign_score_withk.cpp | 85 +++++++++ .../parrots/assign_score_withk_parrots.cpp | 89 +++++++++ .../csrc/parrots/assign_score_withk_pytorch.h | 19 ++ mmcv/ops/csrc/parrots/ball_query._parrots.cpp | 43 +++++ mmcv/ops/csrc/parrots/ball_query.cpp | 37 ++++ mmcv/ops/csrc/parrots/ball_query_pytorch.h | 11 ++ mmcv/ops/csrc/parrots/correlation.cpp | 87 +++++++++ mmcv/ops/csrc/parrots/correlation_parrots.cpp | 176 ++++++++++++++++++ mmcv/ops/csrc/parrots/correlation_pytorch.h | 18 ++ .../csrc/parrots/furthest_point_sample.cpp | 62 ++++++ .../parrots/furthest_point_sample_parrots.cpp | 57 ++++++ .../parrots/furthest_point_sample_pytorch.h | 14 ++ mmcv/ops/csrc/parrots/gather_points.cpp | 55 ++++++ .../csrc/parrots/gather_points_parrots.cpp | 71 +++++++ mmcv/ops/csrc/parrots/gather_points_pytorch.h | 13 ++ mmcv/ops/csrc/parrots/knn.cpp | 32 ++++ mmcv/ops/csrc/parrots/knn_parrots.cpp | 41 ++++ mmcv/ops/csrc/parrots/knn_pytorch.h | 9 + mmcv/ops/csrc/parrots/roipoint_pool3d.cpp | 60 ++++++ .../csrc/parrots/roipoint_pool3d_parrots.cpp | 31 +++ .../csrc/parrots/roipoint_pool3d_pytorch.h | 10 + mmcv/ops/csrc/parrots/three_interpolate.cpp | 61 ++++++ .../parrots/three_interpolate_parrots.cpp | 74 ++++++++ .../csrc/parrots/three_interpolate_pytorch.h | 14 ++ mmcv/ops/csrc/parrots/three_nn.cpp | 30 +++ mmcv/ops/csrc/parrots/three_nn_parrots.cpp | 35 ++++ mmcv/ops/csrc/parrots/three_nn_pytorch.h | 10 + mmcv/ops/csrc/pytorch/assign_score_withk.cpp | 20 +- mmcv/ops/csrc/pytorch/ball_query.cpp | 6 +- .../csrc/pytorch/furthest_point_sample.cpp | 10 +- mmcv/ops/csrc/pytorch/gather_points.cpp | 12 +- mmcv/ops/csrc/pytorch/knn.cpp | 5 +- mmcv/ops/csrc/pytorch/pybind.cpp | 160 ++++++++-------- mmcv/ops/csrc/pytorch/three_interpolate.cpp | 13 +- mmcv/ops/csrc/pytorch/three_nn.cpp | 6 +- mmcv/ops/furthest_point_sample.py | 18 +- mmcv/ops/gather_points.py | 17 +- mmcv/ops/knn.py | 6 +- mmcv/ops/three_interpolate.py | 8 +- mmcv/ops/three_nn.py | 6 +- 43 files changed, 1474 insertions(+), 159 deletions(-) create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk.cpp create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/ball_query._parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/ball_query.cpp create mode 100644 mmcv/ops/csrc/parrots/ball_query_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/correlation.cpp create mode 100644 mmcv/ops/csrc/parrots/correlation_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/correlation_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample.cpp create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/gather_points.cpp create mode 100644 mmcv/ops/csrc/parrots/gather_points_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/gather_points_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/knn.cpp create mode 100644 mmcv/ops/csrc/parrots/knn_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/knn_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d.cpp create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/three_interpolate.cpp create mode 100644 mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/three_interpolate_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/three_nn.cpp create mode 100644 mmcv/ops/csrc/parrots/three_nn_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/three_nn_pytorch.h diff --git a/mmcv/ops/assign_score_withk.py b/mmcv/ops/assign_score_withk.py index 6cca1cb36e..4906adaa2c 100644 --- a/mmcv/ops/assign_score_withk.py +++ b/mmcv/ops/assign_score_withk.py @@ -57,12 +57,19 @@ def forward(ctx, _, npoint, K, _ = scores.size() output = point_features.new_zeros((B, out_dim, npoint, K)) - ext_module.assign_score_withk_forward(B, N, npoint, M, K, out_dim, - agg[aggregate], - point_features.contiguous(), - center_features.contiguous(), - scores.contiguous(), - knn_idx.contiguous(), output) + ext_module.assign_score_withk_forward( + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + output, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg[aggregate]) ctx.save_for_backward(output, point_features, center_features, scores, knn_idx) @@ -92,15 +99,22 @@ def backward(ctx, grad_out): grad_center_features = center_features.new_zeros(center_features.shape) grad_scores = scores.new_zeros(scores.shape) - ext_module.assign_score_withk_backward(B, N, npoint, M, K, out_dim, - agg, grad_out.contiguous(), - point_features.contiguous(), - center_features.contiguous(), - scores.contiguous(), - knn_idx.contiguous(), - grad_point_features, - grad_center_features, - grad_scores) + ext_module.assign_score_withk_backward( + grad_out.contiguous(), + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + grad_point_features, + grad_center_features, + grad_scores, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg) return grad_scores, grad_point_features, \ grad_center_features, None, None diff --git a/mmcv/ops/ball_query.py b/mmcv/ops/ball_query.py index f77bdc8bfa..d0466847c6 100644 --- a/mmcv/ops/ball_query.py +++ b/mmcv/ops/ball_query.py @@ -33,9 +33,18 @@ def forward(ctx, min_radius: float, max_radius: float, sample_num: int, npoint = center_xyz.size(1) idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int) - ext_module.ball_query_forward(B, N, npoint, min_radius, max_radius, - sample_num, center_xyz, xyz, idx) - ctx.mark_non_differentiable(idx) + ext_module.ball_query_forward( + center_xyz, + xyz, + idx, + b=B, + n=N, + m=npoint, + min_radius=min_radius, + max_radius=max_radius, + nsample=sample_num) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return idx @staticmethod diff --git a/mmcv/ops/correlation.py b/mmcv/ops/correlation.py index f6ddedacf2..86dab432c7 100644 --- a/mmcv/ops/correlation.py +++ b/mmcv/ops/correlation.py @@ -39,10 +39,22 @@ def forward(ctx, output = input1.new_zeros(output_size) - ext_module.correlation_forward(input1, input2, output, kH, kW, - patch_size, patch_size, padH, padW, - dilationH, dilationW, dilation_patchH, - dilation_patchW, dH, dW) + ext_module.correlation_forward( + input1, + input2, + output, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) return output @@ -60,11 +72,24 @@ def backward(ctx, grad_output): grad_input1 = torch.zeros_like(input1) grad_input2 = torch.zeros_like(input2) - ext_module.correlation_backward(grad_output, input1, input2, - grad_input1, grad_input2, kH, kW, - patch_size, patch_size, padH, padW, - dilationH, dilationW, dilation_patchH, - dilation_patchW, dH, dW) + ext_module.correlation_backward( + grad_output, + input1, + input2, + grad_input1, + grad_input2, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) return grad_input1, grad_input2, None, None, None, None, None, None @staticmethod diff --git a/mmcv/ops/csrc/parrots/assign_score_withk.cpp b/mmcv/ops/csrc/parrots/assign_score_withk.cpp new file mode 100644 index 0000000000..d35fd24795 --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk.cpp @@ -0,0 +1,85 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void AssignScoreWithKForwardCUDAKernelLauncher( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& points, const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& output); + +void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O, + int aggregate, const Tensor& points, + const Tensor& centers, + const Tensor& scores, + const Tensor& knn_idx, Tensor& output) { + AssignScoreWithKForwardCUDAKernelLauncher( + B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output); +}; + +void AssignScoreWithKBackwardCUDAKernelLauncher( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores); + +void assign_score_withk_backward_cuda( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores) { + AssignScoreWithKBackwardCUDAKernelLauncher( + B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores); +}; +#endif + +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate) { + if (points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(centers); + CHECK_CONTIGUOUS(scores); + CHECK_CONTIGUOUS(knn_idx); + CHECK_CONTIGUOUS(output); + + assign_score_withk_forward_cuda(B, N0, N1, M, K, O, aggregate, points, + centers, scores, knn_idx, output); +#else + AT_ERROR("assign_score_withk is not compiled with GPU support"); +#endif + } else { + AT_ERROR("assign_score_withk is not implemented on CPU"); + } +} + +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate) { + if (grad_points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(scores); + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(centers); + CHECK_CONTIGUOUS(knn_idx); + CHECK_CONTIGUOUS(grad_scores); + CHECK_CONTIGUOUS(grad_points); + CHECK_CONTIGUOUS(grad_centers); + + assign_score_withk_backward_cuda(B, N0, N1, M, K, O, aggregate, grad_out, + points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores); +#else + AT_ERROR("assign_score_withk is not compiled with GPU support"); +#endif + } else { + AT_ERROR("assign_score_withk is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp b/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp new file mode 100644 index 0000000000..5729c71631 --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp @@ -0,0 +1,89 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "assign_score_withk_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void assign_score_withk_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int B, N0, N1, M, K, O, aggregate; + SSAttrs(attr) + .get("B", B) + .get("N0", N0) + .get("N1", N1) + .get("M", M) + .get("K", K) + .get("O", O) + .get("aggregate", aggregate) + .done(); + + const auto& points = buildATensor(ctx, ins[0]); + const auto& centers = buildATensor(ctx, ins[1]); + const auto& scores = buildATensor(ctx, ins[2]); + const auto& knn_idx = buildATensor(ctx, ins[3]); + + auto output = buildATensor(ctx, outs[0]); + assign_score_withk_forward(points, centers, scores, knn_idx, output, B, N0, + N1, M, K, O, aggregate); +} + +void assign_score_withk_backward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int B, N0, N1, M, K, O, aggregate; + SSAttrs(attr) + .get("B", B) + .get("N0", N0) + .get("N1", N1) + .get("M", M) + .get("K", K) + .get("O", O) + .get("aggregate", aggregate) + .done(); + + const auto& grad_out = buildATensor(ctx, ins[0]); + const auto& points = buildATensor(ctx, ins[1]); + const auto& centers = buildATensor(ctx, ins[2]); + const auto& scores = buildATensor(ctx, ins[3]); + const auto& knn_idx = buildATensor(ctx, ins[4]); + + auto grad_points = buildATensor(ctx, outs[0]); + auto grad_centers = buildATensor(ctx, outs[1]); + auto grad_scores = buildATensor(ctx, outs[2]); + assign_score_withk_backward(grad_out, points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores, B, N0, N1, + M, K, O, aggregate); +} + +PARROTS_EXTENSION_REGISTER(assign_score_withk_forward) + .attr("B") + .attr("N0") + .attr("N1") + .attr("M") + .attr("K") + .attr("O") + .attr("aggregate") + .input(4) + .output(1) + .apply(assign_score_withk_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(assign_score_withk_backward) + .attr("B") + .attr("N0") + .attr("N1") + .attr("M") + .attr("K") + .attr("O") + .attr("aggregate") + .input(5) + .output(3) + .apply(assign_score_withk_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h b/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h new file mode 100644 index 0000000000..660594feec --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h @@ -0,0 +1,19 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ASSIGN_SCORE_WITHK_PYTORCH_H +#define ASSIGN_SCORE_WITHK_PYTORCH_H +#include +using namespace at; + +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate); + +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate); + +#endif // ASSIGN_SCORE_WITHK_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/ball_query._parrots.cpp b/mmcv/ops/csrc/parrots/ball_query._parrots.cpp new file mode 100644 index 0000000000..01ab9739b0 --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query._parrots.cpp @@ -0,0 +1,43 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "ball_query_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void ball_query_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m, nsample; + float min_radius, max_radius; + SSAttrs(attr) + .get("b", b) + .get("n", n) + .get("m", m) + .get("nsample", nsample) + .get("min_radius", min_radius) + .get("max_radius", max_radius) + .done(); + + const auto& center_xyz = buildATensor(ctx, ins[0]); + const auto& xyz = buildATensor(ctx, ins[1]); + auto idx = buildATensor(ctx, outs[0]); + ball_query_forward(center_xyz, xyz, idx, b, n, m, min_radius, max_radius, + nsample); +} + +PARROTS_EXTENSION_REGISTER(ball_query_forward) + .attr("b") + .attr("n") + .attr("m") + .attr("nsample") + .attr("min_radius") + .attr("max_radius") + .input(2) + .output(1) + .apply(ball_query_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/ball_query.cpp b/mmcv/ops/csrc/parrots/ball_query.cpp new file mode 100644 index 0000000000..fc2709f0db --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query.cpp @@ -0,0 +1,37 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx); + +void ball_query_forward_cuda(int b, int n, int m, float min_radius, + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx) { + BallQueryForwardCUDAKernelLauncher(b, n, m, min_radius, max_radius, nsample, + new_xyz, xyz, idx); +}; +#endif + +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample) { + if (new_xyz_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(new_xyz_tensor); + CHECK_CUDA_INPUT(xyz_tensor); + + ball_query_forward_cuda(b, n, m, min_radius, max_radius, nsample, + new_xyz_tensor, xyz_tensor, idx_tensor); +#else + AT_ERROR("ball_query is not compiled with GPU support"); +#endif + } else { + AT_ERROR("ball_query is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/ball_query_pytorch.h b/mmcv/ops/csrc/parrots/ball_query_pytorch.h new file mode 100644 index 0000000000..70026f3150 --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query_pytorch.h @@ -0,0 +1,11 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef BALL_QUERY_PYTORCH_H +#define BALL_QUERY_PYTORCH_H +#include +using namespace at; + +void ball_query_forward(const Tensor new_xyz, const Tensor xyz, Tensor idx, + int b, int n, int m, float min_radius, float max_radius, + int nsample); + +#endif // BALL_QUERY_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/correlation.cpp b/mmcv/ops/csrc/parrots/correlation.cpp new file mode 100644 index 0000000000..c3614a500b --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation.cpp @@ -0,0 +1,87 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA + +void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, + Tensor output, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, + Tensor input2, Tensor grad_input1, + Tensor grad_input2, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output, + int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationForwardCUDAKernelLauncher( + input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH, + dilationW, dilation_patchH, dilation_patchW, dH, dW); +} + +void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationBackwardCUDAKernelLauncher( + grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +#endif + +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() && input2.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +#else + AT_ERROR("Correlation is not compiled with GPU support"); +#endif + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } +} + +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() && input2.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(grad_output); + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_backward(grad_output, input1, input2, grad_input1, + grad_input2, kH, kW, patchH, patchW, padH, padW, + dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); + +#else + AT_ERROR("Correlation is not compiled with GPU support"); +#endif + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/correlation_parrots.cpp b/mmcv/ops/csrc/parrots/correlation_parrots.cpp new file mode 100644 index 0000000000..b1e287d063 --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation_parrots.cpp @@ -0,0 +1,176 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "correlation_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void correlation_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto input1 = buildATensor(ctx, ins[0]); + auto input2 = buildATensor(ctx, ins[1]); + + auto output = buildATensor(ctx, outs[0]); + + correlation_forward(input1, input2, output, kH, kW, patchH, patchW, padH, + padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +void correlation_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto grad_output = buildATensor(ctx, ins[0]); + auto input1 = buildATensor(ctx, ins[1]); + auto input2 = buildATensor(ctx, ins[2]); + + auto grad_input1 = buildATensor(ctx, outs[0]); + auto grad_input2 = buildATensor(ctx, outs[1]); + + correlation_backward(grad_output, input1, input2, grad_input1, grad_input2, + kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, dH, dW); +} +#endif + +void correlation_forward_cpu_parrots(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto input1 = buildATensor(ctx, ins[0]); + auto input2 = buildATensor(ctx, ins[1]); + + auto output = buildATensor(ctx, outs[0]); + + correlation_forward(input1, input2, output, kH, kW, patchH, patchW, padH, + padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +void correlation_backward_cpu_parrots(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto grad_output = buildATensor(ctx, ins[0]); + auto input1 = buildATensor(ctx, ins[1]); + auto input2 = buildATensor(ctx, ins[2]); + + auto grad_input1 = buildATensor(ctx, outs[0]); + auto grad_input2 = buildATensor(ctx, outs[1]); + + correlation_backward(grad_output, input1, input2, grad_input1, grad_input2, + kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, dH, dW); +} + +PARROTS_EXTENSION_REGISTER(correlation_forward) + .attr("kH") + .attr("kW") + .attr("patchH") + .attr("patchW") + .attr("padH") + .attr("padW") + .attr("dilationH") + .attr("dilationW") + .attr("dilation_patchH") + .attr("dilation_patchW") + .attr("dH") + .attr("dW") + .input(2) + .output(1) + .apply(correlation_forward_cpu_parrots) +#ifdef MMCV_WITH_CUDA + .apply(correlation_forward_cuda_parrots) +#endif + .done(); + +PARROTS_EXTENSION_REGISTER(correlation_backward) + .attr("kH") + .attr("kW") + .attr("patchH") + .attr("patchW") + .attr("padH") + .attr("padW") + .attr("dilationH") + .attr("dilationW") + .attr("dilation_patchH") + .attr("dilation_patchW") + .attr("dH") + .attr("dW") + .input(3) + .output(2) + .apply(correlation_backward_cpu_parrots) +#ifdef MMCV_WITH_CUDA + .apply(correlation_backward_cuda_parrots) +#endif + .done(); diff --git a/mmcv/ops/csrc/parrots/correlation_pytorch.h b/mmcv/ops/csrc/parrots/correlation_pytorch.h new file mode 100644 index 0000000000..806fcaa710 --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation_pytorch.h @@ -0,0 +1,18 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef CORRELATION_PYTORCH_H +#define CORRELATION_PYTORCH_H +#include +using namespace at; + +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +#endif // CORRELATION_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample.cpp b/mmcv/ops/csrc/parrots/furthest_point_sample.cpp new file mode 100644 index 0000000000..e3ec99a82c --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample.cpp @@ -0,0 +1,62 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m, + const float *dataset, + float *temp, int *idxs); + +void furthest_point_sampling_forward_cuda(int b, int n, int m, + const float *dataset, float *temp, + int *idxs) { + FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs); +} + +void FurthestPointSamplingWithDistForwardCUDAKernelLauncher( + int b, int n, int m, const float *dataset, float *temp, int *idxs); + +void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m, + const float *dataset, + float *temp, int *idxs) { + FurthestPointSamplingWithDistForwardCUDAKernelLauncher(b, n, m, dataset, temp, + idxs); +} +#endif + +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + const float *points = points_tensor.data_ptr(); + float *temp = temp_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + furthest_point_sampling_forward_cuda(b, n, m, points, temp, idx); +#else + AT_ERROR("furthest_point_sampling is not compiled with GPU support"); +#endif + } else { + AT_ERROR("furthest_point_sampling is not implemented on CPU"); + } +} + +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + const float *points = points_tensor.data(); + float *temp = temp_tensor.data(); + int *idx = idx_tensor.data(); + + furthest_point_sampling_with_dist_forward_cuda(b, n, m, points, temp, idx); +#else + AT_ERROR( + "furthest_point_sampling_with_dist is not compiled with GPU support"); +#endif + } else { + AT_ERROR("furthest_point_sampling_with_dist is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp b/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp new file mode 100644 index 0000000000..483bfb2431 --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "furthest_point_sample_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void furthest_point_sample_forward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto temp_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + + furthest_point_sampling_forward(points_tensor, temp_tensor, idx_tensor, b, n, + m); +} + +void furthest_point_sampling_with_dist_forward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto temp_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + + furthest_point_sampling_with_dist_forward(points_tensor, temp_tensor, + idx_tensor, b, n, m); +} +PARROTS_EXTENSION_REGISTER(furthest_point_sampling_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(1) + .apply(furthest_point_sample_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(furthest_point_sampling_with_dist_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(1) + .apply(furthest_point_sampling_with_dist_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h b/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h new file mode 100644 index 0000000000..0325cd66ed --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef FURTHEST_POINT_SAMPLE_PYTORCH_H +#define FURTHEST_POINT_SAMPLE_PYTORCH_H +#include +using namespace at; + +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m); + +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m); +#endif // FURTHEST_POINT_SAMPLE_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/gather_points.cpp b/mmcv/ops/csrc/parrots/gather_points.cpp new file mode 100644 index 0000000000..3ab93b600f --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points.cpp @@ -0,0 +1,55 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + const Tensor points, + const Tensor idx, Tensor out); + +void gather_points_forward_cuda(int b, int c, int n, int npoints, + const Tensor points, const Tensor idx, + Tensor out) { + GatherPointsForwardCUDAKernelLauncher(b, c, n, npoints, points, idx, out); +}; + +void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + const Tensor grad_out, + const Tensor idx, + Tensor grad_points); + +void gather_points_backward_cuda(int b, int c, int n, int npoints, + const Tensor grad_out, const Tensor idx, + Tensor grad_points) { + GatherPointsBackwardCUDAKernelLauncher(b, c, n, npoints, grad_out, idx, + grad_points); +}; +#endif + +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, + int npoints) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor, + out_tensor); +#else + AT_ERROR("gather_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("gather_points is not implemented on CPU"); + } +} + +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor, + grad_points_tensor); +#else + AT_ERROR("gather_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("gather_points is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/gather_points_parrots.cpp b/mmcv/ops/csrc/parrots/gather_points_parrots.cpp new file mode 100644 index 0000000000..1d2d9e1290 --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points_parrots.cpp @@ -0,0 +1,71 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "gather_points_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void gather_points_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, npoints; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("npoints", npoints) + .done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + + auto out_tensor = buildATensor(ctx, outs[0]); + + gather_points_forward(points_tensor, idx_tensor, out_tensor, b, c, n, + npoints); +} + +void gather_points_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, npoints; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("npoints", npoints) + .done(); + + auto grad_out_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + + auto grad_points_tensor = buildATensor(ctx, outs[0]); + + gather_points_backward(grad_out_tensor, idx_tensor, grad_points_tensor, b, c, + n, npoints); +} + +PARROTS_EXTENSION_REGISTER(gather_points_forward) + .attr("b") + .attr("c") + .attr("n") + .attr("npoints") + .input(2) + .output(1) + .apply(gather_points_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(gather_points_backward) + .attr("b") + .attr("c") + .attr("n") + .attr("npoints") + .input(2) + .output(1) + .apply(gather_points_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/gather_points_pytorch.h b/mmcv/ops/csrc/parrots/gather_points_pytorch.h new file mode 100644 index 0000000000..1689ae6ad9 --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points_pytorch.h @@ -0,0 +1,13 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef GATHER_POINTS_PYTORCH_H +#define GATHER_POINTS_PYTORCH_H +#include +using namespace at; + +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, int npoints); + +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints); +#endif // GATHER_POINTS_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/knn.cpp b/mmcv/ops/csrc/parrots/knn.cpp new file mode 100644 index 0000000000..55105eb019 --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn.cpp @@ -0,0 +1,32 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, + const Tensor xyz, const Tensor new_xyz, + Tensor idx, Tensor dist2); + +void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz, + const Tensor new_xyz, Tensor idx, Tensor dist2) { + KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2); +} +#endif + +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample) { + if (new_xyz_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(new_xyz_tensor); + CHECK_CUDA_INPUT(xyz_tensor); + + knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor, + dist2_tensor); +#else + AT_ERROR("knn is not compiled with GPU support"); +#endif + } else { + AT_ERROR("knn is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/knn_parrots.cpp b/mmcv/ops/csrc/parrots/knn_parrots.cpp new file mode 100644 index 0000000000..585b84644a --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn_parrots.cpp @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "knn_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void knn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m, nsample; + SSAttrs(attr) + .get("b", b) + .get("n", n) + .get("m", m) + .get("nsample", nsample) + .done(); + + auto xyz_tensor = buildATensor(ctx, ins[0]); + auto new_xyz_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + auto dist2_tensor = buildATensor(ctx, outs[1]); + + knn_forward(xyz_tensor, new_xyz_tensor, idx_tensor, dist2_tensor, b, n, m, + nsample); +} + +PARROTS_EXTENSION_REGISTER(knn_forward) + .attr("b") + .attr("n") + .attr("m") + .attr("nsample") + .input(2) + .output(2) + .apply(knn_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/knn_pytorch.h b/mmcv/ops/csrc/parrots/knn_pytorch.h new file mode 100644 index 0000000000..b0875f8389 --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn_pytorch.h @@ -0,0 +1,9 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef KNN_PYTORCH_H +#define KNN_PYTORCH_H +#include +using namespace at; + +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample); +#endif // KNN_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp b/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp new file mode 100644 index 0000000000..e9b5054e70 --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp @@ -0,0 +1,60 @@ +/* +Modified from +https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d.cpp +Point cloud feature pooling +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void RoIPointPool3dForwardCUDAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); + +void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + RoIPointPool3dForwardCUDAKernelLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz, + boxes3d, pts_feature, pooled_features, pooled_empty_flag); +}; +#endif + +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + if (xyz.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(xyz); + CHECK_CUDA_INPUT(boxes3d); + CHECK_CUDA_INPUT(pts_feature); + CHECK_CUDA_INPUT(pooled_features); + CHECK_CUDA_INPUT(pooled_empty_flag); + + int batch_size = xyz.size(0); + int pts_num = xyz.size(1); + int boxes_num = boxes3d.size(1); + int feature_in_len = pts_feature.size(2); + int sampled_pts_num = pooled_features.size(2); + + roipoint_pool3d_forward_cuda(batch_size, pts_num, boxes_num, feature_in_len, + sampled_pts_num, xyz, boxes3d, pts_feature, + pooled_features, pooled_empty_flag); +#else + AT_ERROR("roipoint_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roipoint_pool3d is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp b/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp new file mode 100644 index 0000000000..17f549849d --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "roipoint_pool3d_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void roipoint_pool3d_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + auto xyz = buildATensor(ctx, ins[0]); + auto boxes3d = buildATensor(ctx, ins[1]); + auto pts_feature = buildATensor(ctx, ins[2]); + + auto pooled_features = buildATensor(ctx, outs[0]); + auto pooled_empty_flag = buildATensor(ctx, outs[1]); + + roipoint_pool3d_forward(xyz, boxes3d, pts_feature, pooled_features, + pooled_empty_flag); +} + +PARROTS_EXTENSION_REGISTER(roipoint_pool3d_forward) + .input(3) + .output(2) + .apply(roipoint_pool3d_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h b/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h new file mode 100644 index 0000000000..e5b61b0d9a --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROIPOINT_POOL3D_PYTORCH_H +#define ROIPOINT_POOL3D_PYTORCH_H +#include +using namespace at; + +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag); + +#endif // ROIPOINT_POOL3D_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/three_interpolate.cpp b/mmcv/ops/csrc/parrots/three_interpolate.cpp new file mode 100644 index 0000000000..dbbcd995d0 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate.cpp @@ -0,0 +1,61 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n, + const Tensor points, + const Tensor idx, + const Tensor weight, Tensor out); + +void three_interpolate_forward_cuda(int b, int c, int m, int n, + const Tensor points, const Tensor idx, + const Tensor weight, Tensor out) { + ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight, + out); +}; + +void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m, + const Tensor grad_out, + const Tensor idx, + const Tensor weight, + Tensor grad_points); + +void three_interpolate_backward_cuda(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points) { + ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight, + grad_points); +}; +#endif + +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor, + weight_tensor, out_tensor); +#else + AT_ERROR("three_interpolate is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_interpolate is not implemented on CPU"); + } +} + +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor, + weight_tensor, grad_points_tensor); +#else + AT_ERROR("three_interpolate is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_interpolate is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp b/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp new file mode 100644 index 0000000000..a71a90fd1e --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp @@ -0,0 +1,74 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "three_interpolate_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void three_interpolate_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, m, n; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("m", m) + .get("n", n) + .done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + auto weight_tensor = buildATensor(ctx, ins[2]); + + auto out_tensor = buildATensor(ctx, outs[0]); + + three_interpolate_forward(points_tensor, idx_tensor, weight_tensor, + out_tensor, b, c, m, n); +} + +void three_interpolate_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, m; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("m", m) + .done(); + + auto grad_out_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + auto weight_tensor = buildATensor(ctx, ins[2]); + + auto grad_points_tensor = buildATensor(ctx, outs[0]); + + three_interpolate_backward(grad_out_tensor, idx_tensor, weight_tensor, + grad_points_tensor, b, c, n, m); +} + +PARROTS_EXTENSION_REGISTER(three_interpolate_forward) + .attr("b") + .attr("c") + .attr("m") + .attr("n") + .input(3) + .output(1) + .apply(three_interpolate_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(three_interpolate_backward) + .attr("b") + .attr("c") + .attr("n") + .attr("m") + .input(3) + .output(1) + .apply(three_interpolate_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h b/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h new file mode 100644 index 0000000000..464c6d9005 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef THREE_INTERPOLATE_PYTORCH_H +#define THREE_INTERPOLATE_PYTORCH_H +#include +using namespace at; + +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n); + +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m); +#endif // THREE_INTERPOLATE_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/three_nn.cpp b/mmcv/ops/csrc/parrots/three_nn.cpp new file mode 100644 index 0000000000..158ac00231 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn.cpp @@ -0,0 +1,30 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, + Tensor idx); + +void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx) { + ThreeNNForwardCUDAKernelLauncher(b, n, m, unknown, known, dist2, idx); +}; +#endif + +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m) { + if (unknown_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor, + idx_tensor); +#else + AT_ERROR("three_nn is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_nn is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/three_nn_parrots.cpp b/mmcv/ops/csrc/parrots/three_nn_parrots.cpp new file mode 100644 index 0000000000..c28c7d216c --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn_parrots.cpp @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "three_nn_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void three_nn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto unknown_tensor = buildATensor(ctx, ins[0]); + auto known_tensor = buildATensor(ctx, ins[1]); + + auto dist2_tensor = buildATensor(ctx, outs[0]); + auto idx_tensor = buildATensor(ctx, outs[1]); + + three_nn_forward(unknown_tensor, known_tensor, dist2_tensor, idx_tensor, b, n, + m); +} + +PARROTS_EXTENSION_REGISTER(three_nn_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(2) + .apply(three_nn_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/three_nn_pytorch.h b/mmcv/ops/csrc/parrots/three_nn_pytorch.h new file mode 100644 index 0000000000..6574fba091 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn_pytorch.h @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef THREE_NN_PYTORCH_H +#define THREE_NN_PYTORCH_H +#include +using namespace at; + +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m); +#endif // THREE_NN_PYTORCH_H diff --git a/mmcv/ops/csrc/pytorch/assign_score_withk.cpp b/mmcv/ops/csrc/pytorch/assign_score_withk.cpp index 36bd16432f..d35fd24795 100644 --- a/mmcv/ops/csrc/pytorch/assign_score_withk.cpp +++ b/mmcv/ops/csrc/pytorch/assign_score_withk.cpp @@ -34,10 +34,10 @@ void assign_score_withk_backward_cuda( }; #endif -void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& points, - const Tensor& centers, const Tensor& scores, - const Tensor& knn_idx, Tensor& output) { +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate) { if (points.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CONTIGUOUS(points); @@ -56,12 +56,12 @@ void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, } } -void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& grad_out, - const Tensor& points, const Tensor& centers, - const Tensor& scores, const Tensor& knn_idx, - Tensor& grad_points, Tensor& grad_centers, - Tensor& grad_scores) { +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate) { if (grad_points.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CONTIGUOUS(grad_out); diff --git a/mmcv/ops/csrc/pytorch/ball_query.cpp b/mmcv/ops/csrc/pytorch/ball_query.cpp index 0a0892ba16..fc2709f0db 100644 --- a/mmcv/ops/csrc/pytorch/ball_query.cpp +++ b/mmcv/ops/csrc/pytorch/ball_query.cpp @@ -18,9 +18,9 @@ void ball_query_forward_cuda(int b, int n, int m, float min_radius, }; #endif -void ball_query_forward(int b, int n, int m, float min_radius, float max_radius, - int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor, - Tensor idx_tensor) { +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample) { if (new_xyz_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CUDA_INPUT(new_xyz_tensor); diff --git a/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp b/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp index a7bc060a82..e3ec99a82c 100644 --- a/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp +++ b/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp @@ -25,8 +25,8 @@ void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m, } #endif -void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, - Tensor temp_tensor, Tensor idx_tensor) { +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA const float *points = points_tensor.data_ptr(); @@ -41,10 +41,10 @@ void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, } } -void furthest_point_sampling_with_dist_forward(int b, int n, int m, - Tensor points_tensor, +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, Tensor temp_tensor, - Tensor idx_tensor) { + Tensor idx_tensor, int b, int n, + int m) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA const float *points = points_tensor.data(); diff --git a/mmcv/ops/csrc/pytorch/gather_points.cpp b/mmcv/ops/csrc/pytorch/gather_points.cpp index a56e933c8d..3ab93b600f 100644 --- a/mmcv/ops/csrc/pytorch/gather_points.cpp +++ b/mmcv/ops/csrc/pytorch/gather_points.cpp @@ -24,9 +24,9 @@ void gather_points_backward_cuda(int b, int c, int n, int npoints, }; #endif -void gather_points_forward(int b, int c, int n, int npoints, - Tensor points_tensor, Tensor idx_tensor, - Tensor out_tensor) { +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, + int npoints) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor, @@ -39,9 +39,9 @@ void gather_points_forward(int b, int c, int n, int npoints, } } -void gather_points_backward(int b, int c, int n, int npoints, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor grad_points_tensor) { +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints) { if (grad_out_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor, diff --git a/mmcv/ops/csrc/pytorch/knn.cpp b/mmcv/ops/csrc/pytorch/knn.cpp index fbbbfc8f2b..55105eb019 100644 --- a/mmcv/ops/csrc/pytorch/knn.cpp +++ b/mmcv/ops/csrc/pytorch/knn.cpp @@ -14,9 +14,8 @@ void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz, } #endif -void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, - Tensor new_xyz_tensor, Tensor idx_tensor, - Tensor dist2_tensor) { +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample) { if (new_xyz_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CUDA_INPUT(new_xyz_tensor); diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 8f52e26e82..1845737f3d 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -4,17 +4,17 @@ std::string get_compiler_version(); std::string get_compiling_cuda_version(); -void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor &points, - const Tensor ¢ers, const Tensor &scores, - const Tensor &knn_idx, Tensor &output); - -void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor &grad_out, - const Tensor &points, const Tensor ¢ers, - const Tensor &scores, const Tensor &knn_idx, - Tensor &grad_points, Tensor &grad_centers, - Tensor &grad_scores); +void assign_score_withk_forward(const Tensor &points, const Tensor ¢ers, + const Tensor &scores, const Tensor &knn_idx, + Tensor &output, int B, int N0, int N1, int M, + int K, int O, int aggregate); + +void assign_score_withk_backward(const Tensor &grad_out, const Tensor &points, + const Tensor ¢ers, const Tensor &scores, + const Tensor &knn_idx, Tensor &grad_points, + Tensor &grad_centers, Tensor &grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate); void carafe_naive_forward(Tensor features, Tensor masks, Tensor output, int kernel_size, int group_size, int scale_factor); @@ -76,13 +76,12 @@ void group_points_backward(int b, int c, int n, int npoints, int nsample, void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); -void gather_points_forward(int b, int c, int n, int npoints, - Tensor points_tensor, Tensor idx_tensor, - Tensor out_tensor); +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, int npoints); -void gather_points_backward(int b, int c, int n, int npoints, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor grad_points_tensor); +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints); void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); @@ -97,22 +96,23 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight, Tensor buff, Tensor grad_input, float gamma, float alpha); -void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, - Tensor idx_tensor, Tensor weight_tensor, - Tensor out_tensor); +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n); -void three_interpolate_backward(int b, int c, int n, int m, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor weight_tensor, - Tensor grad_points_tensor); +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m); -void three_nn_forward(int b, int n, int m, Tensor unknown_tensor, - Tensor known_tensor, Tensor dist2_tensor, - Tensor idx_tensor); +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m); void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample); void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_overlap); @@ -124,16 +124,13 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh); int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh); -void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, - Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor); +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m); -void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, - Tensor temp_tensor, Tensor idx_tensor); - -void furthest_point_sampling_with_dist_forward(int b, int n, int m, - Tensor points_tensor, +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, Tensor temp_tensor, - Tensor idx_tensor); + Tensor idx_tensor, int b, int n, + int m); void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx, const Tensor mask_w_idx, Tensor col, @@ -238,9 +235,9 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output); void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input); -void ball_query_forward(int b, int n, int m, float min_radius, float max_radius, - int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor, - Tensor idx_tensor); +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample); Tensor bottom_pool_forward(Tensor input); @@ -352,32 +349,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), py::arg("scale")); m.def("gather_points_forward", &gather_points_forward, - "gather_points_forward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("npoints"), py::arg("points_tensor"), py::arg("idx_tensor"), - py::arg("out_tensor")); + "gather_points_forward", py::arg("points_tensor"), + py::arg("idx_tensor"), py::arg("out_tensor"), py::arg("b"), + py::arg("c"), py::arg("n"), py::arg("npoints")); m.def("gather_points_backward", &gather_points_backward, - "gather_points_backward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("npoints"), py::arg("grad_out_tensor"), py::arg("idx_tensor"), - py::arg("grad_points_tensor")); + "gather_points_backward", py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"), + py::arg("c"), py::arg("n"), py::arg("npoints")); m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); m.def("get_compiling_cuda_version", &get_compiling_cuda_version, "get_compiling_cuda_version"); m.def("assign_score_withk_forward", &assign_score_withk_forward, - "assign_score_withk_forward", py::arg("B"), py::arg("N0"), - py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"), - py::arg("aggregate"), py::arg("points"), py::arg("centers"), - py::arg("scores"), py::arg("knn_idx"), py::arg("output")); + "assign_score_withk_forward", py::arg("points"), py::arg("centers"), + py::arg("scores"), py::arg("knn_idx"), py::arg("output"), py::arg("B"), + py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"), + py::arg("aggregate")); m.def("assign_score_withk_backward", &assign_score_withk_backward, - "assign_score_withk_backward", py::arg("B"), py::arg("N0"), - py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"), - py::arg("aggregate"), py::arg("grad_out"), py::arg("points"), + "assign_score_withk_backward", py::arg("grad_out"), py::arg("points"), py::arg("centers"), py::arg("scores"), py::arg("knn_idx"), - py::arg("grad_points"), py::arg("grad_centers"), - py::arg("grad_scores")); - m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), + py::arg("grad_points"), py::arg("grad_centers"), py::arg("grad_scores"), + py::arg("B"), py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"), + py::arg("O"), py::arg("aggregate")); + m.def("knn_forward", &knn_forward, "knn_forward", py::arg("xyz_tensor"), py::arg("new_xyz_tensor"), py::arg("idx_tensor"), - py::arg("dist2_tensor")); + py::arg("dist2_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), + py::arg("nsample")); m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", py::arg("features"), py::arg("masks"), py::arg("output"), py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); @@ -447,17 +443,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("buff"), py::arg("grad_input"), py::arg("gamma"), py::arg("alpha")); m.def("three_interpolate_forward", &three_interpolate_forward, - "three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"), - py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"), - py::arg("weight_tensor"), py::arg("out_tensor")); + "three_interpolate_forward", py::arg("points_tensor"), + py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"), + py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n")); m.def("three_interpolate_backward", &three_interpolate_backward, - "three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"), - py::arg("weight_tensor"), py::arg("grad_points_tensor")); - m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"), - py::arg("n"), py::arg("m"), py::arg("unknown_tensor"), - py::arg("known_tensor"), py::arg("dist2_tensor"), - py::arg("idx_tensor")); + "three_interpolate_backward", py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("weight_tensor"), + py::arg("grad_points_tensor"), py::arg("b"), py::arg("c"), py::arg("n"), + py::arg("m")); + m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", + py::arg("unknown_tensor"), py::arg("known_tensor"), + py::arg("dist2_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); @@ -485,14 +482,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"), py::arg("nms_overlap_thresh")); m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward, - "furthest_point_sampling_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"), - py::arg("idx_tensor")); + "furthest_point_sampling_forward", py::arg("points_tensor"), + py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("furthest_point_sampling_with_dist_forward", &furthest_point_sampling_with_dist_forward, - "furthest_point_sampling_with_dist_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"), - py::arg("idx_tensor")); + "furthest_point_sampling_with_dist_forward", py::arg("points_tensor"), + py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("masked_im2col_forward", &masked_im2col_forward, "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), @@ -609,9 +606,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), py::arg("iou_threshold"), py::arg("multi_label")); m.def("ball_query_forward", &ball_query_forward, "ball_query_forward", + py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), - py::arg("max_radius"), py::arg("nsample"), py::arg("new_xyz_tensor"), - py::arg("xyz_tensor"), py::arg("idx_tensor")); + py::arg("max_radius"), py::arg("nsample")); m.def("roi_align_rotated_forward", &roi_align_rotated_forward, "roi_align_rotated forward", py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), @@ -657,6 +654,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "backward function of border_align", py::arg("grad_output"), py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size")); + m.def("correlation_forward", &correlation_forward, "Correlation forward", + py::arg("input1"), py::arg("input2"), py::arg("output"), py::arg("kH"), + py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"), + py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"), + py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"), + py::arg("dW")); + m.def("correlation_backward", &correlation_backward, "Correlation backward", + py::arg("grad_output"), py::arg("input1"), py::arg("input2"), + py::arg("grad_input1"), py::arg("grad_input2"), py::arg("kH"), + py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"), + py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"), + py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"), + py::arg("dW")); m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward, "points_in_boxes_cpu_forward", py::arg("boxes_tensor"), py::arg("pts_tensor"), py::arg("pts_indices_tensor")); @@ -674,6 +684,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"), py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"), py::arg("pool_method")); - m.def("correlation_forward", &correlation_forward, "Correlation forward"); - m.def("correlation_backward", &correlation_backward, "Correlation backward"); } diff --git a/mmcv/ops/csrc/pytorch/three_interpolate.cpp b/mmcv/ops/csrc/pytorch/three_interpolate.cpp index 71a7e09cec..dbbcd995d0 100644 --- a/mmcv/ops/csrc/pytorch/three_interpolate.cpp +++ b/mmcv/ops/csrc/pytorch/three_interpolate.cpp @@ -30,9 +30,9 @@ void three_interpolate_backward_cuda(int b, int c, int n, int m, }; #endif -void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, - Tensor idx_tensor, Tensor weight_tensor, - Tensor out_tensor) { +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor, @@ -45,10 +45,9 @@ void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, } } -void three_interpolate_backward(int b, int c, int n, int m, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor weight_tensor, - Tensor grad_points_tensor) { +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m) { if (grad_out_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor, diff --git a/mmcv/ops/csrc/pytorch/three_nn.cpp b/mmcv/ops/csrc/pytorch/three_nn.cpp index cba70746c6..158ac00231 100644 --- a/mmcv/ops/csrc/pytorch/three_nn.cpp +++ b/mmcv/ops/csrc/pytorch/three_nn.cpp @@ -14,9 +14,9 @@ void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown, }; #endif -void three_nn_forward(int b, int n, int m, Tensor unknown_tensor, - Tensor known_tensor, Tensor dist2_tensor, - Tensor idx_tensor) { +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m) { if (unknown_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor, diff --git a/mmcv/ops/furthest_point_sample.py b/mmcv/ops/furthest_point_sample.py index 11cf5fbf51..374b7a878f 100644 --- a/mmcv/ops/furthest_point_sample.py +++ b/mmcv/ops/furthest_point_sample.py @@ -30,9 +30,16 @@ def forward(ctx, points_xyz: torch.Tensor, output = torch.cuda.IntTensor(B, num_points) temp = torch.cuda.FloatTensor(B, N).fill_(1e10) - ext_module.furthest_point_sampling_forward(B, N, num_points, - points_xyz, temp, output) - ctx.mark_non_differentiable(output) + ext_module.furthest_point_sampling_forward( + points_xyz, + temp, + output, + b=B, + n=N, + m=num_points, + ) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) return output @staticmethod @@ -62,8 +69,9 @@ def forward(ctx, points_dist: torch.Tensor, temp = points_dist.new_zeros([B, N]).fill_(1e10) ext_module.furthest_point_sampling_with_dist_forward( - B, N, num_points, points_dist, temp, output) - ctx.mark_non_differentiable(output) + points_dist, temp, output, b=B, n=N, m=num_points) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) return output @staticmethod diff --git a/mmcv/ops/gather_points.py b/mmcv/ops/gather_points.py index 67c6c3f8e0..f52f1677d8 100644 --- a/mmcv/ops/gather_points.py +++ b/mmcv/ops/gather_points.py @@ -28,11 +28,12 @@ def forward(ctx, features: torch.Tensor, _, C, N = features.size() output = torch.cuda.FloatTensor(B, C, npoint) - ext_module.gather_points_forward(B, C, N, npoint, features, indices, - output) + ext_module.gather_points_forward( + features, indices, output, b=B, c=C, n=N, npoints=npoint) ctx.for_backwards = (indices, C, N) - ctx.mark_non_differentiable(indices) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(indices) return output @staticmethod @@ -42,8 +43,14 @@ def backward(ctx, grad_out): grad_features = torch.cuda.FloatTensor(B, C, N).zero_() grad_out_data = grad_out.data.contiguous() - ext_module.gather_points_backward(B, C, N, npoint, grad_out_data, idx, - grad_features.data) + ext_module.gather_points_backward( + grad_out_data, + idx, + grad_features.data, + b=B, + c=C, + n=N, + npoints=npoint) return grad_features, None diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py index e3428a91c5..f335785036 100644 --- a/mmcv/ops/knn.py +++ b/mmcv/ops/knn.py @@ -61,10 +61,12 @@ def forward(ctx, idx = center_xyz.new_zeros((B, npoint, k)).int() dist2 = center_xyz.new_zeros((B, npoint, k)).float() - ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2) + ext_module.knn_forward( + xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) # idx shape to [B, k, npoint] idx = idx.transpose(2, 1).contiguous() - ctx.mark_non_differentiable(idx) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return idx @staticmethod diff --git a/mmcv/ops/three_interpolate.py b/mmcv/ops/three_interpolate.py index 48aa8c5a39..203f47f05d 100644 --- a/mmcv/ops/three_interpolate.py +++ b/mmcv/ops/three_interpolate.py @@ -39,8 +39,8 @@ def forward(ctx, features: torch.Tensor, indices: torch.Tensor, ctx.three_interpolate_for_backward = (indices, weight, m) output = torch.cuda.FloatTensor(B, c, n) - ext_module.three_interpolate_forward(B, c, m, n, features, indices, - weight, output) + ext_module.three_interpolate_forward( + features, indices, weight, output, b=B, c=c, m=m, n=n) return output @staticmethod @@ -60,8 +60,8 @@ def backward( grad_features = torch.cuda.FloatTensor(B, c, m).zero_() grad_out_data = grad_out.data.contiguous() - ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx, - weight, grad_features.data) + ext_module.three_interpolate_backward( + grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m) return grad_features, None, None diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py index 459f611a4e..2b01047a12 100644 --- a/mmcv/ops/three_nn.py +++ b/mmcv/ops/three_nn.py @@ -37,9 +37,9 @@ def forward(ctx, target: torch.Tensor, dist2 = torch.cuda.FloatTensor(B, N, 3) idx = torch.cuda.IntTensor(B, N, 3) - ext_module.three_nn_forward(B, N, m, target, source, dist2, idx) - - ctx.mark_non_differentiable(idx) + ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return torch.sqrt(dist2), idx From 01bc35e44da66ecfa2b59cfe527270e97424b93d Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sat, 23 Oct 2021 20:51:44 +0800 Subject: [PATCH 06/30] [Feature] Loading objects from different backends and dumping objects to different backends (#1330) * [Feature] Choose storage backend by the prefix of filepath * refactor FileClient and add unittest * support loading from different backends * polish docstring * fix unittet * rename attribute str_like_obj to is_str_like_obj * add infer_client method * add check_exist method * rename var client to file_client * polish docstring * add join_paths method * remove join_paths and add _format_path * enhance unittest * refactor unittest * singleton pattern * fix test_clientio.py * deprecate CephBackend * enhance docstring * refactor unittest for petrel * refactor unittest for disk backend * update io.md * add concat_paths method * improve docstring * improve docstring * add isdir and copyfile for file backend * delete copyfile and add get_local_path * remove isdir method of petrel * fix typo * add comment and polish docstring * polish docstring * rename _path_mapping to _map_path * polish docstring and fix typo * refactor get_local_path * add list_dir_or_file for FileClient * add list_dir_or_file for PetrelBackend * fix windows ci * Add return docstring * polish docstring * fix typo * fix typo * deprecate the conversion from Path to str * add docs for loading checkpoints with FileClient * refactor map_path * add _ensure_methods to ensure methods have been implemented * fix list_dir_or_file * rename _ensure_method_implemented to has_method --- docs/understand_mmcv/io.md | 128 +++- docs_zh_CN/understand_mmcv/io.md | 129 +++- mmcv/fileio/file_client.py | 875 ++++++++++++++++++++++++- mmcv/fileio/handlers/base.py | 6 + mmcv/fileio/handlers/pickle_handler.py | 2 + mmcv/fileio/io.py | 49 +- mmcv/fileio/parse.py | 53 +- mmcv/image/photometric.py | 4 +- mmcv/utils/__init__.py | 6 +- mmcv/utils/misc.py | 13 + tests/test_fileclient.py | 600 ++++++++++++++++- tests/test_fileio.py | 70 +- tests/test_utils/test_misc.py | 16 + 13 files changed, 1860 insertions(+), 91 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index 50314d13d0..f6c28dd425 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -2,11 +2,17 @@ This module provides two universal API to load and dump files of different formats. +```{note} +Since v1.3.16, the IO modules support loading (dumping) data from (to) different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). +``` + ### Load and dump data `mmcv` provides a universal api for loading and dumping data, currently supported formats are json, yaml and pickle. +#### Load from disk or dump to disk + ```python import mmcv @@ -29,6 +35,20 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` +#### Load from other backends or dump to other backends + +```python +import mmcv + +# load data from a file +data = mmcv.load('s3://bucket-name/test.json') +data = mmcv.load('s3://bucket-name/test.yaml') +data = mmcv.load('s3://bucket-name/test.pkl') + +# dump data to a file with a filename (infer format from file extension) +mmcv.dump(data, 's3://bucket-name/out.pkl') +``` + It is also very convenient to extend the api to support more file formats. All you need to do is to write a file handler inherited from `BaseFileHandler` and register it with one or several file formats. @@ -92,7 +112,9 @@ d e ``` -Then use `list_from_file` to load the list from a.txt. +#### Load from disk + +Use `list_from_file` to load the list from a.txt. ```python >>> mmcv.list_from_file('a.txt') @@ -113,7 +135,7 @@ For example `b.txt` is a text file with 3 lines. 3 panda ``` -Then use `dict_from_file` to load the dict from `b.txt` . +Then use `dict_from_file` to load the dict from `b.txt`. ```python >>> mmcv.dict_from_file('b.txt') @@ -121,3 +143,105 @@ Then use `dict_from_file` to load the dict from `b.txt` . >>> mmcv.dict_from_file('b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + +#### Load from other backends + +Use `list_from_file` to load the list from `s3://bucket-name/a.txt`. + +```python +>>> mmcv.list_from_file('s3://bucket-name/a.txt') +['a', 'b', 'c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2) +['c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2) +['a', 'b'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/') +['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] +``` + +Use `dict_from_file` to load the dict from `s3://bucket-name/b.txt`. + +```python +>>> mmcv.dict_from_file('s3://bucket-name/b.txt') +{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} +>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) +{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} +``` + +### Load and dump checkpoints + +#### Load checkpoints from disk or save to disk + +We can read the checkpoints from disk or save to disk in the following way. + +```python +import torch + +filepath1 = '/path/of/your/checkpoint1.pth' +filepath2 = '/path/of/your/checkpoint2.pth' +# read from filepath1 +checkpoint = torch.load(filepath1) +# save to filepath2 +torch.save(checkpoint, filepath2) +``` + +MMCV provides many backends. `HardDiskBackend` is one of them and we can use it to read or save checkpoints. + +```python +import io +from mmcv.fileio.file_client import HardDiskBackend + +disk_backend = HardDiskBackend() +with io.BytesIO(disk_backend.get(filepath1)) as buffer: + checkpoint = torch.load(buffer) +with io.BytesIO() as buffer: + torch.save(checkpoint, f) + disk_backend.put(f.getvalue(), filepath2) +``` + +If we want to implement an interface which automatically select the corresponding +backend based on the file path, we can use the `FileClient`. +For example, we want to implement two methods for reading checkpoints as well as saving checkpoints, +which need to support different types of file paths, either disk paths, network paths or other paths. + +```python +from mmcv.fileio.file_client import FileClient + +def load_checkpoint(path): + file_client = FileClient.infer(uri=path) + with io.BytesIO(file_client.get(path)) as buffer: + checkpoint = torch.load(buffer) + return checkpoint + +def save_checkpoint(checkpoint, path): + with io.BytesIO() as buffer: + torch.save(checkpoint, buffer) + file_client.put(buffer.getvalue(), path) + +file_client = FileClient.infer_client(uri=filepath1) +checkpoint = load_checkpoint(filepath1) +save_checkpoint(checkpoint, filepath2) +``` + +#### Load checkpoints from the Internet + +```{note} +Currently, it only supports reading checkpoints from the Internet, and does not support saving checkpoints to the Internet. +``` + +```python +import io +import torch +from mmcv.fileio.file_client import HTTPBackend, FileClient + +filepath = 'http://path/of/your/checkpoint.pth' +checkpoint = torch.utils.model_zoo.load_url(filepath) + +http_backend = HTTPBackend() +with io.BytesIO(http_backend.get(filepath)) as buffer: + checkpoint = torch.load(buffer) + +file_client = FileClient.infer_client(uri=filepath) +with io.BytesIO(file_client.get(filepath)) as buffer: + checkpoint = torch.load(buffer) +``` diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 8d3844f77c..0e5002f828 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -2,10 +2,16 @@ 文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。 +```{note} +在 v1.3.16 及之后的版本中,IO 模块支持从不同后端读取数据并支持将数据至不同后端。更多细节请访问 PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330)。 +``` + ### 读取和保存数据 `mmcv` 提供了一个通用的 api 用于读取和保存数据,目前支持的格式有 json、yaml 和 pickle。 +#### 从硬盘读取数据或者将数据保存至硬盘 + ```python import mmcv @@ -28,6 +34,20 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` +#### 从其他后端加载或者保存至其他后端 + +```python +import mmcv + +# 从 s3 文件读取数据 +data = mmcv.load('s3://bucket-name/test.json') +data = mmcv.load('s3://bucket-name/test.yaml') +data = mmcv.load('s3://bucket-name/test.pkl') + +# 将数据保存至 s3 文件 (根据文件名后缀反推文件类型) +mmcv.dump(data, 's3://bucket-name/out.pkl') +``` + 我们提供了易于拓展的方式以支持更多的文件格式。我们只需要创建一个继承自 `BaseFileHandler` 的 文件句柄类并将其注册到 `mmcv` 中即可。句柄类至少需要重写三个方法。 @@ -49,7 +69,7 @@ class TxtHandler1(mmcv.BaseFileHandler): return str(obj) ``` -举 `PickleHandler` 为例。 +以 `PickleHandler` 为例 ```python import pickle @@ -87,8 +107,9 @@ c d e ``` +#### 从硬盘读取 -使用 `list_from_file` 读取 `a.txt` 。 +使用 `list_from_file` 读取 `a.txt` ```python >>> mmcv.list_from_file('a.txt') @@ -101,7 +122,7 @@ e ['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] ``` -同样, `b.txt` 也是文本文件,一共有3行内容。 +同样, `b.txt` 也是文本文件,一共有3行内容 ``` 1 cat @@ -109,7 +130,7 @@ e 3 panda ``` -使用 `dict_from_file` 读取 `b.txt` 。 +使用 `dict_from_file` 读取 `b.txt` ```python >>> mmcv.dict_from_file('b.txt') @@ -117,3 +138,103 @@ e >>> mmcv.dict_from_file('b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + +#### 从其他后端读取 + +使用 `list_from_file` 读取 `s3://bucket-name/a.txt` + +```python +>>> mmcv.list_from_file('s3://bucket-name/a.txt') +['a', 'b', 'c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2) +['c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2) +['a', 'b'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/') +['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] +``` + +使用 `dict_from_file` 读取 `b.txt` + +```python +>>> mmcv.dict_from_file('s3://bucket-name/b.txt') +{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} +>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) +{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} +``` + +### 读取和保存权重文件 + +#### 从硬盘读取权重文件或者将权重文件保存至硬盘 + +我们可以通过下面的方式从磁盘读取权重文件或者将权重文件保存至磁盘 + +```python +import torch + +filepath1 = '/path/of/your/checkpoint1.pth' +filepath2 = '/path/of/your/checkpoint2.pth' +# 从 filepath1 读取权重文件 +checkpoint = torch.load(filepath1) +# 将权重文件保存至 filepath2 +torch.save(checkpoint, filepath2) +``` + +MMCV 提供了很多后端,`HardDiskBackend` 是其中一个,我们可以通过它来读取或者保存权重文件。 + +```python +import io +from mmcv.fileio.file_client import HardDiskBackend + +disk_backend = HardDiskBackend() +with io.BytesIO(disk_backend.get(filepath1)) as buffer: + checkpoint = torch.load(buffer) +with io.BytesIO() as buffer: + torch.save(checkpoint, f) + disk_backend.put(f.getvalue(), filepath2) +``` + +如果我们想在接口中实现根据文件路径自动选择对应的后端,我们可以使用 `FileClient`。 +例如,我们想实现两个方法,分别是读取权重以及保存权重,它们需支持不同类型的文件路径,可以是磁盘路径,也可以是网络路径或者其他路径。 + +```python +from mmcv.fileio.file_client import FileClient + +def load_checkpoint(path): + file_client = FileClient.infer(uri=path) + with io.BytesIO(file_client.get(path)) as buffer: + checkpoint = torch.load(buffer) + return checkpoint + +def save_checkpoint(checkpoint, path): + with io.BytesIO() as buffer: + torch.save(checkpoint, buffer) + file_client.put(buffer.getvalue(), path) + +file_client = FileClient.infer_client(uri=filepath1) +checkpoint = load_checkpoint(filepath1) +save_checkpoint(checkpoint, filepath2) +``` + +#### 从网络远端读取权重文件 + +```{note} +目前只支持从网络远端读取权重文件,暂不支持将权重文件写入网络远端 +``` + +```python +import io +import torch +from mmcv.fileio.file_client import HTTPBackend, FileClient + +filepath = 'http://path/of/your/checkpoint.pth' +checkpoint = torch.utils.model_zoo.load_url(filepath) + +http_backend = HTTPBackend() +with io.BytesIO(http_backend.get(filepath)) as buffer: + checkpoint = torch.load(buffer) + +file_client = FileClient.infer_client(uri=filepath) +with io.BytesIO(file_client.get(filepath)) as buffer: + checkpoint = torch.load(buffer) +``` diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index e8a6cbdb08..a6c0f8b89e 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -1,8 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import os +import os.path as osp +import re +import tempfile +import warnings from abc import ABCMeta, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable, Iterator, Optional, Tuple, Union from urllib.request import urlopen +from mmcv.utils.misc import has_method +from mmcv.utils.path import is_filepath + class BaseStorageBackend(metaclass=ABCMeta): """Abstract class of storage backends. @@ -22,12 +33,16 @@ def get_text(self, filepath): class CephBackend(BaseStorageBackend): - """Ceph storage backend. + """Ceph storage backend (for internal use). Args: path_mapping (dict|None): path mapping dict from local path to Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` will be replaced by ``dst``. Default: None. + + .. warning:: + :class:`CephBackend` will be deprecated, please use + :class:`PetrelBackend` instead """ def __init__(self, path_mapping=None): @@ -36,6 +51,8 @@ def __init__(self, path_mapping=None): except ImportError: raise ImportError('Please install ceph to enable CephBackend.') + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') self._client = ceph.S3Client() assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping @@ -49,21 +66,36 @@ def get(self, filepath): value_buf = memoryview(value) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError class PetrelBackend(BaseStorageBackend): """Petrel storage backend (for internal use). + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + Args: - path_mapping (dict|None): path mapping dict from local path to Petrel - path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will - be replaced by `dst`. Default: None. - enable_mc (bool): whether to enable memcached support. Default: True. + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Default: None. + enable_mc (bool, optional): Whether to enable memcached support. + Default: True. + + Examples: + >>> filepath1 = 's3://path/of/file' + >>> filepath2 = 'cluster-name:s3://path/of/file' + >>> client = PetrelBackend() + >>> client.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster """ - def __init__(self, path_mapping=None, enable_mc=True): + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): try: from petrel_client import client except ImportError: @@ -74,17 +106,296 @@ def __init__(self, path_mapping=None, enable_mc=True): assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping - def get(self, filepath): + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str): Path to be mapped. + """ filepath = str(filepath) if self.path_mapping is not None: for k, v in self.path_mapping.items(): filepath = filepath.replace(k, v) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def get(self, filepath: Union[str, Path]) -> memoryview: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + memoryview: A memory view of expected bytes object to avoid + copying. The memoryview object can be converted to bytes by + ``value_buf.tobytes()``. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) value = self._client.Get(filepath) value_buf = memoryview(value) return value_buf - def get_text(self, filepath): - raise NotImplementedError + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Save data to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.put(filepath, obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Save data to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Default: 'utf-8'. + """ + self.put(bytes(obj, encoding=encoding), filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.delete(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) + + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_paths.append(self._format_path(self._map_path(path))) + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str | Path): Download a file from ``filepath``. + + Examples: + >>> client = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('s3://path/of/your/file') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if not has_method(self._client, 'list'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.')) + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.concat_paths(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.concat_paths(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) class MemcachedBackend(BaseStorageBackend): @@ -121,7 +432,7 @@ def get(self, filepath): value_buf = mc.ConvertBuffer(self._mc_buffer) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError @@ -173,25 +484,185 @@ def get(self, filepath): value_buf = txn.get(filepath.encode('ascii')) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError class HardDiskBackend(BaseStorageBackend): """Raw hard disks storage backend.""" - def get(self, filepath): - filepath = str(filepath) + def get(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ with open(filepath, 'rb') as f: value_buf = f.read() return value_buf - def get_text(self, filepath): - filepath = str(filepath) - with open(filepath, 'r') as f: + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, 'r', encoding=encoding) as f: value_buf = f.read() return value_buf + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + os.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check a ``filepath`` whether it is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return osp.isfile(filepath) + + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: + """Only for unified API and do nothing.""" + yield filepath + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -200,21 +671,70 @@ def get(self, filepath): value_buf = urlopen(filepath).read() return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): value_buf = urlopen(filepath).read() - return value_buf.decode('utf-8') + return value_buf.decode(encoding) + + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: + """Download a file from ``filepath``. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> client = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('http://path/of/your/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) class FileClient: - """A general file client to access files in different backend. + """A general file client to access files in different backends. The client loads a file or text in a specified backend from its path - and return it as a binary file. it can also register other backend - accessor with a given name and backend class. + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object is returned. + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Default: None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='petrel') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='petrel', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='petrel') + >>> file_client1 is file_client + True Attributes: - backend (str): The storage backend type. Options are "disk", "ceph", - "memcached", "lmdb" and "http". client (:obj:`BaseStorageBackend`): The backend object. """ @@ -226,17 +746,117 @@ class FileClient: 'petrel': PetrelBackend, 'http': HTTPBackend, } + # This collection is used to record the overridden backends, and when a + # backend appears in the collection, the singleton pattern is disabled for + # that backend, because if the singleton pattern is used, then the object + # returned will be the backend before overwriting + _overridden_backends = set() + _prefix_to_backends = { + 's3': PetrelBackend, + 'http': HTTPBackend, + 'https': HTTPBackend, + } + _overridden_prefixes = set() + + _instances = {} - def __init__(self, backend='disk', **kwargs): - if backend not in self._backends: + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = 'disk' + if backend is not None and backend not in cls._backends: raise ValueError( f'Backend {backend} is not supported. Currently supported ones' - f' are {list(self._backends.keys())}') - self.backend = backend - self.client = self._backends[backend](**kwargs) + f' are {list(cls._backends.keys())}') + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f'prefix {prefix} is not supported. Currently supported ones ' + f'are {list(cls._prefix_to_backends.keys())}') + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f'{backend}:{prefix}' + for key, value in kwargs.items(): + arg_key += f':{key}:{value}' + + # if a backend was overridden, it will create a new object + if (arg_key in cls._instances + and backend not in cls._overridden_backends + and prefix not in cls._overridden_prefixes): + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + _instance.backend_name = backend + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + # infer the backend name according to the prefix + for backend_name, backend_cls in cls._backends.items(): + if isinstance(_instance.client, backend_cls): + _instance.backend_name = backend_name + break + cls._instances[arg_key] = _instance + + return _instance + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' + else ``None``. + """ + assert is_filepath(uri) + uri = str(uri) + if '://' not in uri: + return None + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contains the cluster + # name like clusterName:s3 + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix @classmethod - def _register_backend(cls, name, backend, force=False): + def infer_client(cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None) -> 'FileClient': + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Default: None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 'petrel'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): if not isinstance(name, str): raise TypeError('the backend name should be a string, ' f'but got {type(name)}') @@ -251,10 +871,28 @@ def _register_backend(cls, name, backend, force=False): f'{name} is already registered as a storage backend, ' 'add "force=True" if you want to override it') + if name in cls._backends and force: + cls._overridden_backends.add(name) cls._backends[name] = backend + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + cls._overridden_prefixes.add(prefix) + cls._prefix_to_backends[prefix] = backend + else: + raise KeyError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + @classmethod - def register_backend(cls, name, backend=None, force=False): + def register_backend(cls, name, backend=None, force=False, prefixes=None): """Register a backend to FileClient. This method can be used as a normal class method or a decorator. @@ -292,19 +930,184 @@ def get_text(self, filepath): Defaults to None. force (bool, optional): Whether to override the backend if the name has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Default: None. + `New in version 1.3.15.` """ if backend is not None: - cls._register_backend(name, backend, force=force) + cls._register_backend( + name, backend, force=force, prefixes=prefixes) return def _register(backend_cls): - cls._register_backend(name, backend_cls, force=force) + cls._register_backend( + name, backend_cls, force=force, prefixes=prefixes) return backend_cls return _register - def get(self, filepath): + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ return self.client.get(filepath) - def get_text(self, filepath): - return self.client.get_text(filepath) + def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.concat_paths(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, + suffix, recursive) diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index 235727557c..5f28b0acc6 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -3,6 +3,12 @@ class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + + str_like = True @abstractmethod def load_from_fileobj(self, file, **kwargs): diff --git a/mmcv/fileio/handlers/pickle_handler.py b/mmcv/fileio/handlers/pickle_handler.py index 0250459957..b37c79bed4 100644 --- a/mmcv/fileio/handlers/pickle_handler.py +++ b/mmcv/fileio/handlers/pickle_handler.py @@ -6,6 +6,8 @@ class PickleHandler(BaseFileHandler): + str_like = False + def load_from_fileobj(self, file, **kwargs): return pickle.load(file, **kwargs) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 015d36e808..aaefde58aa 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO from pathlib import Path from ..utils import is_list_of, is_str +from .file_client import FileClient from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler file_handlers = { @@ -13,11 +15,15 @@ } -def load(file, file_format=None, **kwargs): +def load(file, file_format=None, file_client_args=None, **kwargs): """Load data from json/yaml/pickle files. This method provides a unified api for loading data from serialized files. + Note: + In v1.3.16 and later, ``load`` supports loading data from serialized + files those can be storaged in different backends. + Args: file (str or :obj:`Path` or file-like object): Filename or a file-like object. @@ -25,6 +31,14 @@ def load(file, file_format=None, **kwargs): inferred from the file extension, otherwise use the specified one. Currently supported formats include "json", "yaml/yml" and "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in petrel Returns: The content from the file. @@ -38,7 +52,13 @@ def load(file, file_format=None, **kwargs): handler = file_handlers[file_format] if is_str(file): - obj = handler.load_from_path(file, **kwargs) + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO(file_client.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_client.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) elif hasattr(file, 'read'): obj = handler.load_from_fileobj(file, **kwargs) else: @@ -46,18 +66,29 @@ def load(file, file_format=None, **kwargs): return obj -def dump(obj, file=None, file_format=None, **kwargs): +def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): """Dump data to json/yaml/pickle strings or files. This method provides a unified api for dumping data as strings or to files, and also supports custom arguments for each file format. + Note: + In v1.3.16 and later, ``dump`` supports dumping data as strings or to + files which is saved to different backends. + Args: obj (any): The python object to be dumped. file (str or :obj:`Path` or file-like object, optional): If not - specified, then the object is dump to a str, otherwise to a file + specified, then the object is dumped to a str, otherwise to a file specified by the filename or file-like object. file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel Returns: bool: True for success, False otherwise. @@ -77,7 +108,15 @@ def dump(obj, file=None, file_format=None, **kwargs): if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - handler.dump_to_path(obj, file, **kwargs) + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index 987c9f1104..f60f0d611b 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -1,7 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. -def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): + +from io import StringIO + +from .file_client import FileClient + + +def list_from_file(filename, + prefix='', + offset=0, + max_num=0, + encoding='utf-8', + file_client_args=None): """Load a text file and parse the content as a list of strings. + Note: + In v1.3.16 and later, ``list_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a list for strings. + Args: filename (str): Filename. prefix (str): The prefix to be inserted to the beginning of each item. @@ -9,13 +25,23 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): max_num (int): The maximum number of lines to be read, zeros and negatives mean no limitation. encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> list_from_file('/path/of/your/file') # disk + ['hello', 'world'] + >>> list_from_file('s3://path/of/your/file') # ceph or petrel + ['hello', 'world'] Returns: list[str]: A list of strings. """ cnt = 0 item_list = [] - with open(filename, 'r', encoding=encoding) as f: + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: for _ in range(offset): f.readline() for line in f: @@ -26,23 +52,42 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): return item_list -def dict_from_file(filename, key_type=str): +def dict_from_file(filename, + key_type=str, + encoding='utf-8', + file_client_args=None): """Load a text file and parse the content as a dict. Each line of the text file will be two or more columns split by whitespaces or tabs. The first column will be parsed as dict keys, and the following columns will be parsed as dict values. + Note: + In v1.3.16 and later, ``dict_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a dict. + Args: filename(str): Filename. key_type(type): Type of the dict keys. str is user by default and type conversion will be performed if specified. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dict_from_file('/path/of/your/file') # disk + {'key1': 'value1', 'key2': 'value2'} + >>> dict_from_file('s3://path/of/your/file') # ceph or petrel + {'key1': 'value1', 'key2': 'value2'} Returns: dict: The parsed contents. """ mapping = {} - with open(filename, 'r') as f: + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: for line in f: items = line.rstrip('\n').split() assert len(items) >= 2 diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py index 3c1f68f1f5..5085d01201 100644 --- a/mmcv/image/photometric.py +++ b/mmcv/image/photometric.py @@ -309,9 +309,9 @@ def adjust_sharpness(img, factor=1., kernel=None): kernel (np.ndarray, optional): Filter kernel to be applied on the img to obtain the degenerated img. Defaults to None. - Notes:: + Note: No value sanity check is enforced on the kernel set by users. So with - an inappropriate kernel, the `adjust_sharpness` may fail to perform + an inappropriate kernel, the ``adjust_sharpness`` may fail to perform the function its name indicates but end up performing whatever transform determined by the kernel. diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index baf8109f05..378a006843 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .config import Config, ConfigDict, DictAction from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - import_modules_from_strings, is_list_of, + has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, @@ -33,7 +33,7 @@ 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_method_overridden' + 'is_method_overridden', 'has_method' ] else: from .env import collect_env @@ -65,5 +65,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home' + '_get_cuda_home', 'has_method' ] diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py index 488cbdd81f..2c58d0d7fe 100644 --- a/mmcv/utils/misc.py +++ b/mmcv/utils/misc.py @@ -362,3 +362,16 @@ def is_method_overridden(method, base_class, derived_class): base_method = getattr(base_class, method) derived_method = getattr(derived_class, method) return derived_method != base_method + + +def has_method(obj: object, method: str) -> bool: + """Check whether the object has a method. + + Args: + method (str): The method name to check. + obj (object): The object to check. + + Returns: + bool: True if the object has the method else False. + """ + return hasattr(obj, method) and callable(getattr(obj, method)) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 80357cf31d..d15483c94c 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,4 +1,9 @@ +import os +import os.path as osp import sys +import tempfile +from contextlib import contextmanager +from copy import deepcopy from pathlib import Path from unittest.mock import MagicMock, patch @@ -6,6 +11,7 @@ import mmcv from mmcv import BaseStorageBackend, FileClient +from mmcv.utils import has_method sys.modules['ceph'] = MagicMock() sys.modules['petrel_client'] = MagicMock() @@ -13,6 +19,51 @@ sys.modules['mc'] = MagicMock() +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . \n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + +@contextmanager +def delete_and_reset_method(obj, method): + method_obj = deepcopy(getattr(type(obj), method)) + try: + delattr(type(obj), method) + yield + finally: + setattr(type(obj), method, method_obj) + + class MockS3Client: def __init__(self, enable_mc=True): @@ -24,6 +75,37 @@ def Get(self, filepath): return content +class MockPetrelClient: + + def __init__(self, enable_mc=True, enable_multi_cluster=False): + self.enable_mc = enable_mc + self.enable_multi_cluster = enable_multi_cluster + + def Get(self, filepath): + with open(filepath, 'rb') as f: + content = f.read() + return content + + def put(self): + pass + + def delete(self): + pass + + def contains(self): + pass + + def isdir(self): + pass + + def list(self, dir_path): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + yield entry.name + elif osp.isdir(entry.path): + yield entry.name + '/' + + class MockMemcachedClient: def __init__(self, server_list_cfg, client_cfg): @@ -50,6 +132,7 @@ def test_error(self): def test_disk_backend(self): disk_backend = FileClient('disk') + # test `get` # input path is Path object img_bytes = disk_backend.get(self.img_path) img = mmcv.imfrombytes(img_bytes) @@ -61,6 +144,7 @@ def test_disk_backend(self): assert self.img_path.open('rb').read() == img_bytes assert img.shape == self.img_shape + # test `get_text` # input path is Path object value_buf = disk_backend.get_text(self.text_path) assert self.text_path.open('r').read() == value_buf @@ -68,6 +152,118 @@ def test_disk_backend(self): value_buf = disk_backend.get_text(str(self.text_path)) assert self.text_path.open('r').read() == value_buf + with tempfile.TemporaryDirectory() as tmp_dir: + # test `put` + filepath1 = Path(tmp_dir) / 'test.jpg' + disk_backend.put(b'disk', filepath1) + assert filepath1.open('rb').read() == b'disk' + + # test `put_text` + filepath2 = Path(tmp_dir) / 'test.txt' + disk_backend.put_text('disk', filepath2) + assert filepath2.open('r').read() == 'disk' + + # test `isfile` + assert disk_backend.isfile(filepath2) + assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path') + + # test `remove` + disk_backend.remove(filepath2) + + # test `exists` + assert not disk_backend.exists(filepath2) + + # test `get_local_path` + # if the backend is disk, `get_local_path` just return the input + with disk_backend.get_local_path(filepath1) as path: + assert str(filepath1) == path + assert osp.isfile(filepath1) + + # test `concat_paths` + disk_dir = '/path/of/your/directory' + assert disk_backend.concat_paths(disk_dir, 'file') == \ + osp.join(disk_dir, 'file') + assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \ + osp.join(disk_dir, 'dir', 'file') + + # test `list_dir_or_file` + with build_temporary_directory() as tmp_dir: + # 1. list directories and files + assert set(disk_backend.list_dir_or_file(tmp_dir)) == set( + ['dir1', 'dir2', 'text1.txt', 'text2.txt']) + # 2. list directories and files recursively + assert set(disk_backend.list_dir_or_file( + tmp_dir, recursive=True)) == set([ + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 3. only list directories + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_file=False)) == set(['dir1', 'dir2']) + with pytest.raises( + TypeError, + match='`suffix` should be None when `list_dir` is True'): + # Exception is raised among the `list_dir_or_file` of client, + # so we need to invode the client to trigger the exception + disk_backend.client.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + # 4. only list directories recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)) == set( + ['dir1', 'dir2', + osp.join('dir2', 'dir3')]) + # 5. only list files + assert set(disk_backend.list_dir_or_file( + tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt']) + # 6. only list files recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == set(['text1.txt', 'text2.txt']) + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', + '.jpg'))) == set(['text1.txt', 'text2.txt']) + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + disk_backend.client.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + # 8. only list files ending with suffix recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + disk_backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + @patch('ceph.S3Client', MockS3Client) def test_ceph_backend(self): ceph_backend = FileClient('ceph') @@ -103,16 +299,11 @@ def test_ceph_backend(self): ceph_backend.client._client.Get.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), ceph_path)) - @patch('petrel_client.client.Client', MockS3Client) - def test_petrel_backend(self): - petrel_backend = FileClient('petrel') - - # input path is Path object - with pytest.raises(NotImplementedError): - petrel_backend.get_text(self.text_path) - # input path is str - with pytest.raises(NotImplementedError): - petrel_backend.get_text(str(self.text_path)) + @patch('petrel_client.client.Client', MockPetrelClient) + @pytest.mark.parametrize('backend,prefix', [('petrel', None), + (None, 's3')]) + def test_petrel_backend(self, backend, prefix): + petrel_backend = FileClient(backend=backend, prefix=prefix) # input path is Path object img_bytes = petrel_backend.get(self.img_path) @@ -126,17 +317,209 @@ def test_petrel_backend(self): # `path_mapping` is either None or dict with pytest.raises(AssertionError): FileClient('petrel', path_mapping=1) - # test `path_mapping` - petrel_path = 's3://user/data' + + # test `_map_path` + petrel_dir = 's3://user/data' petrel_backend = FileClient( - 'petrel', path_mapping={str(self.test_data_dir): petrel_path}) - petrel_backend.client._client.Get = MagicMock( - return_value=petrel_backend.client._client.Get(self.img_path)) - img_bytes = petrel_backend.get(self.img_path) - img = mmcv.imfrombytes(img_bytes) - assert img.shape == self.img_shape - petrel_backend.client._client.Get.assert_called_with( - str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + 'petrel', path_mapping={str(self.test_data_dir): petrel_dir}) + assert petrel_backend.client._map_path(str(self.img_path)) == \ + str(self.img_path).replace(str(self.test_data_dir), petrel_dir) + + petrel_path = f'{petrel_dir}/test.jpg' + petrel_backend = FileClient('petrel') + + # test `_format_path` + assert petrel_backend.client._format_path('s3://user\\data\\test.jpg')\ + == petrel_path + + # test `get` + with patch.object( + petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get: + assert petrel_backend.get(petrel_path) == b'petrel' + mock_get.assert_called_once_with(petrel_path) + + # test `get_text` + with patch.object( + petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get: + assert petrel_backend.get_text(petrel_path) == 'petrel' + mock_get.assert_called_once_with(petrel_path) + + # test `put` + with patch.object(petrel_backend.client._client, 'put') as mock_put: + petrel_backend.put(b'petrel', petrel_path) + mock_put.assert_called_once_with(petrel_path, b'petrel') + + # test `put_text` + with patch.object(petrel_backend.client._client, 'put') as mock_put: + petrel_backend.put_text('petrel', petrel_path) + mock_put.assert_called_once_with(petrel_path, b'petrel') + + # test `remove` + assert has_method(petrel_backend.client._client, 'delete') + # raise Exception if `delete` is not implemented + with delete_and_reset_method(petrel_backend.client._client, 'delete'): + assert not has_method(petrel_backend.client._client, 'delete') + with pytest.raises(NotImplementedError): + petrel_backend.remove(petrel_path) + + with patch.object(petrel_backend.client._client, + 'delete') as mock_delete: + petrel_backend.remove(petrel_path) + mock_delete.assert_called_once_with(petrel_path) + + # test `exists` + assert has_method(petrel_backend.client._client, 'contains') + assert has_method(petrel_backend.client._client, 'isdir') + # raise Exception if `delete` is not implemented + with delete_and_reset_method(petrel_backend.client._client, + 'contains'), delete_and_reset_method( + petrel_backend.client._client, + 'isdir'): + assert not has_method(petrel_backend.client._client, 'contains') + assert not has_method(petrel_backend.client._client, 'isdir') + with pytest.raises(NotImplementedError): + petrel_backend.exists(petrel_path) + + with patch.object( + petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + assert petrel_backend.exists(petrel_path) + mock_contains.assert_called_once_with(petrel_path) + + # test `isdir` + assert has_method(petrel_backend.client._client, 'isdir') + with delete_and_reset_method(petrel_backend.client._client, 'isdir'): + assert not has_method(petrel_backend.client._client, 'isdir') + with pytest.raises(NotImplementedError): + petrel_backend.isdir(petrel_path) + + with patch.object( + petrel_backend.client._client, 'isdir', + return_value=True) as mock_isdir: + assert petrel_backend.isdir(petrel_dir) + mock_isdir.assert_called_once_with(petrel_dir) + + # test `isfile` + assert has_method(petrel_backend.client._client, 'contains') + with delete_and_reset_method(petrel_backend.client._client, + 'contains'): + assert not has_method(petrel_backend.client._client, 'contains') + with pytest.raises(NotImplementedError): + petrel_backend.isfile(petrel_path) + + with patch.object( + petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + assert petrel_backend.isfile(petrel_path) + mock_contains.assert_called_once_with(petrel_path) + + # test `concat_paths` + assert petrel_backend.concat_paths(petrel_dir, 'file') == \ + f'{petrel_dir}/file' + assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \ + f'{petrel_dir}/file' + assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ + f'{petrel_dir}/dir/file' + + # test `get_local_path` + with patch.object(petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get, \ + patch.object(petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + with petrel_backend.get_local_path(petrel_path) as path: + assert Path(path).open('rb').read() == b'petrel' + # exist the with block and path will be released + assert not osp.isfile(path) + mock_get.assert_called_once_with(petrel_path) + mock_contains.assert_called_once_with(petrel_path) + + # test `list_dir_or_file` + assert has_method(petrel_backend.client._client, 'list') + with delete_and_reset_method(petrel_backend.client._client, 'list'): + assert not has_method(petrel_backend.client._client, 'list') + with pytest.raises(NotImplementedError): + list(petrel_backend.list_dir_or_file(petrel_dir)) + + with build_temporary_directory() as tmp_dir: + # 1. list directories and files + assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set( + ['dir1', 'dir2', 'text1.txt', 'text2.txt']) + # 2. list directories and files recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, recursive=True)) == set([ + 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', + '/'.join(('dir2', 'dir3')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + ]) + # 3. only list directories + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_file=False)) == set(['dir1', 'dir2']) + with pytest.raises( + TypeError, + match=('`list_dir` should be False when `suffix` is not ' + 'None')): + # Exception is raised among the `list_dir_or_file` of client, + # so we need to invode the client to trigger the exception + petrel_backend.client.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + # 4. only list directories recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)) == set( + ['dir1', 'dir2', '/'.join(('dir2', 'dir3'))]) + # 5. only list files + assert set( + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False)) == set( + ['text1.txt', 'text2.txt']) + # 6. only list files recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)) == set([ + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == set(['text1.txt', 'text2.txt']) + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', + '.jpg'))) == set(['text1.txt', 'text2.txt']) + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + petrel_backend.client.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + # 8. only list files ending with suffix recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', + recursive=True)) == set([ + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', + 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == set([ + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + ]) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) @@ -182,8 +565,10 @@ def test_lmdb_backend(self): img = mmcv.imfrombytes(img_bytes) assert img.shape == (120, 125, 3) - def test_http_backend(self): - http_backend = FileClient('http') + @pytest.mark.parametrize('backend,prefix', [('http', None), + (None, 'http')]) + def test_http_backend(self, backend, prefix): + http_backend = FileClient(backend=backend, prefix=prefix) img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ 'master/tests/data/color.jpg' text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ @@ -208,6 +593,84 @@ def test_http_backend(self): value_buf = http_backend.get_text(text_url) assert self.text_path.open('r').read() == value_buf + # test `_get_local_path` + # exist the with block and path will be released + with http_backend.get_local_path(img_url) as path: + assert mmcv.imread(path).shape == self.img_shape + assert not osp.isfile(path) + + def test_new_magic_method(self): + + class DummyBackend1(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath, encoding='utf-8'): + return filepath + + FileClient.register_backend('dummy_backend', DummyBackend1) + client1 = FileClient(backend='dummy_backend') + client2 = FileClient(backend='dummy_backend') + assert client1 is client2 + + # if a backend is overwrote, it will disable the singleton pattern for + # the backend + class DummyBackend2(BaseStorageBackend): + + def get(self, filepath): + pass + + def get_text(self, filepath): + pass + + FileClient.register_backend('dummy_backend', DummyBackend2, force=True) + client3 = FileClient(backend='dummy_backend') + client4 = FileClient(backend='dummy_backend') + assert client3 is not client4 + + def test_parse_uri_prefix(self): + # input path is None + with pytest.raises(AssertionError): + FileClient.parse_uri_prefix(None) + # input path is list + with pytest.raises(AssertionError): + FileClient.parse_uri_prefix([]) + + # input path is Path object + assert FileClient.parse_uri_prefix(self.img_path) is None + # input path is str + assert FileClient.parse_uri_prefix(str(self.img_path)) is None + + # input path starts with https + img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ + 'master/tests/data/color.jpg' + assert FileClient.parse_uri_prefix(img_url) == 'https' + + # input path starts with s3 + img_url = 's3://your_bucket/img.png' + assert FileClient.parse_uri_prefix(img_url) == 's3' + + # input path starts with clusterName:s3 + img_url = 'clusterName:s3://your_bucket/img.png' + assert FileClient.parse_uri_prefix(img_url) == 's3' + + def test_infer_client(self): + # HardDiskBackend + file_client_args = {'backend': 'disk'} + client = FileClient.infer_client(file_client_args) + assert client.backend_name == 'disk' + client = FileClient.infer_client(uri=self.img_path) + assert client.backend_name == 'disk' + + # PetrelBackend + file_client_args = {'backend': 'petrel'} + client = FileClient.infer_client(file_client_args) + assert client.backend_name == 'petrel' + uri = 's3://user_data' + client = FileClient.infer_client(uri=uri) + assert client.backend_name == 'petrel' + def test_register_backend(self): # name must be a string @@ -235,7 +698,7 @@ class ExampleBackend(BaseStorageBackend): def get(self, filepath): return filepath - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return filepath FileClient.register_backend('example', ExampleBackend) @@ -247,9 +710,9 @@ def get_text(self, filepath): class Example2Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes2' + return b'bytes2' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text2' # force=False @@ -258,20 +721,20 @@ def get_text(self, filepath): FileClient.register_backend('example', Example2Backend, force=True) example_backend = FileClient('example') - assert example_backend.get(self.img_path) == 'bytes2' + assert example_backend.get(self.img_path) == b'bytes2' assert example_backend.get_text(self.text_path) == 'text2' @FileClient.register_backend(name='example3') class Example3Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes3' + return b'bytes3' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text3' example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == 'bytes3' + assert example_backend.get(self.img_path) == b'bytes3' assert example_backend.get_text(self.text_path) == 'text3' assert 'example3' in FileClient._backends @@ -282,20 +745,89 @@ def get_text(self, filepath): class Example4Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes4' + return b'bytes4' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text4' @FileClient.register_backend(name='example3', force=True) class Example5Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes5' + return b'bytes5' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text5' example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == 'bytes5' + assert example_backend.get(self.img_path) == b'bytes5' assert example_backend.get_text(self.text_path) == 'text5' + + # prefixes is a str + class Example6Backend(BaseStorageBackend): + + def get(self, filepath): + return b'bytes6' + + def get_text(self, filepath, encoding='utf-8'): + return 'text6' + + FileClient.register_backend( + 'example4', + Example6Backend, + force=True, + prefixes='example4_prefix') + example_backend = FileClient('example4') + assert example_backend.get(self.img_path) == b'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + example_backend = FileClient(prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + example_backend = FileClient('example4', prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + + # prefixes is a list of str + class Example7Backend(BaseStorageBackend): + + def get(self, filepath): + return b'bytes7' + + def get_text(self, filepath, encoding='utf-8'): + return 'text7' + + FileClient.register_backend( + 'example5', + Example7Backend, + force=True, + prefixes=['example5_prefix1', 'example5_prefix2']) + example_backend = FileClient('example5') + assert example_backend.get(self.img_path) == b'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + example_backend = FileClient(prefix='example5_prefix1') + assert example_backend.get(self.img_path) == b'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + example_backend = FileClient(prefix='example5_prefix2') + assert example_backend.get(self.img_path) == b'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + + # backend has a higher priority than prefixes + class Example8Backend(BaseStorageBackend): + + def get(self, filepath): + return b'bytes8' + + def get_text(self, filepath, encoding='utf-8'): + return 'text8' + + FileClient.register_backend( + 'example6', + Example8Backend, + force=True, + prefixes='example6_prefix') + example_backend = FileClient('example6') + assert example_backend.get(self.img_path) == b'bytes8' + assert example_backend.get_text(self.text_path) == 'text8' + example_backend = FileClient('example6', prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes8' + assert example_backend.get_text(self.text_path) == 'text8' diff --git a/tests/test_fileio.py b/tests/test_fileio.py index a9d70f515a..556a44a133 100644 --- a/tests/test_fileio.py +++ b/tests/test_fileio.py @@ -1,11 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import sys import tempfile +from unittest.mock import MagicMock, patch import pytest import mmcv +from mmcv.fileio.file_client import HTTPBackend, PetrelBackend + +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() def _test_handler(file_format, test_obj, str_checker, mode='r+'): @@ -13,7 +19,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): dump_str = mmcv.dump(test_obj, file_format=file_format) str_checker(dump_str) - # load/dump with filenames + # load/dump with filenames from disk tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump') mmcv.dump(test_obj, tmp_filename, file_format=file_format) assert osp.isfile(tmp_filename) @@ -21,6 +27,13 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): assert load_obj == test_obj os.remove(tmp_filename) + # load/dump with filename from petrel + method = 'put' if 'b' in mode else 'put_text' + with patch.object(PetrelBackend, method, return_value=None) as mock_method: + filename = 's3://path/of/your/file' + mmcv.dump(test_obj, filename, file_format=file_format) + mock_method.assert_called() + # json load/dump with a file-like object with tempfile.NamedTemporaryFile(mode, delete=False) as f: tmp_filename = f.name @@ -122,6 +135,7 @@ def dump_to_str(self, obj, **kwargs): def test_list_from_file(): + # get list from disk filename = osp.join(osp.dirname(__file__), 'data/filelist.txt') filelist = mmcv.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg'] @@ -134,10 +148,64 @@ def test_list_from_file(): filelist = mmcv.list_from_file(filename, offset=3, max_num=3) assert filelist == ['4.jpg', '5.jpg'] + # get list from http + with patch.object( + HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + filename = 'http://path/of/your/file' + filelist = mmcv.list_from_file( + filename, file_client_args={'backend': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file( + filename, file_client_args={'prefix': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file(filename) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + + # get list from petrel + with patch.object( + PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + filename = 's3://path/of/your/file' + filelist = mmcv.list_from_file( + filename, file_client_args={'backend': 'petrel'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file( + filename, file_client_args={'prefix': 's3'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file(filename) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + def test_dict_from_file(): + # get dict from disk filename = osp.join(osp.dirname(__file__), 'data/mapping.txt') mapping = mmcv.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file(filename, key_type=int) assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} + + # get dict from http + with patch.object( + HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): + filename = 'http://path/of/your/file' + mapping = mmcv.dict_from_file( + filename, file_client_args={'backend': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file( + filename, file_client_args={'prefix': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file(filename) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + + # get dict from petrel + with patch.object( + PetrelBackend, 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): + filename = 's3://path/of/your/file' + mapping = mmcv.dict_from_file( + filename, file_client_args={'backend': 'petrel'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file( + filename, file_client_args={'prefix': 's3'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file(filename) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 1c7ed5a9ff..6070624c3e 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -3,6 +3,7 @@ import mmcv from mmcv import deprecated_api_warning +from mmcv.utils.misc import has_method def test_to_ntuple(): @@ -193,6 +194,21 @@ def foo1(): mmcv.is_method_overridden('foo1', base_instance, sub_instance) +def test_has_method(): + + class Foo: + + def __init__(self, name): + self.name = name + + def print_name(self): + print(self.name) + + foo = Foo('foo') + assert not has_method(foo, 'name') + assert has_method(foo, 'print_name') + + def test_deprecated_api_warning(): @deprecated_api_warning(name_dict=dict(old_key='new_key')) From bb78be4689358083028d472e0642e1bd9d9a638b Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Sun, 24 Oct 2021 13:32:31 +0800 Subject: [PATCH 07/30] Add CI for pytorch 1.10 (#1431) --- .github/workflows/build.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c97088688d..9c52827ed2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -211,10 +211,12 @@ jobs: strategy: matrix: python-version: [3.7] - torch: [1.9.0+cu102] + torch: [1.9.0+cu102, 1.10.0+cu102] include: - torch: 1.9.0+cu102 torchvision: 0.10.0+cu102 + - torch: 1.10.0+cu102 + torchvision: 0.11.0+cu102 - python-version: 3.6 torch: 1.9.0+cu102 torchvision: 0.10.0+cu102 From 2150d27803265c7a31b03e32d2ed73955035fdc1 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 24 Oct 2021 14:26:52 +0800 Subject: [PATCH 08/30] [Feature] Upload checkpoints and logs to ceph (#1375) * [Feature] Choose storage backend by the prefix of filepath * refactor FileClient and add unittest * support loading from different backends * polish docstring * fix unittet * rename attribute str_like_obj to is_str_like_obj * [Docs] Upload checkpoint to petrel oss * add infer_client method * Support uploading checkpoint to petrel oss * add check_exist method * refactor CheckpointHook * support uploading logs to ceph * rename var client to file_client * polish docstring * enhance load_from_ceph * refactor load_from_ceph * refactor TextLoggerHook * change the meaning of out_dir argument * fix test_checkpoint_hook.py * add join_paths method * remove join_paths and add _format_path * enhance unittest * refactor unittest * add a unittest for EvalHook when file backend is petrel * singleton pattern * fix test_clientio.py * deprecate CephBackend * add warning in load_from_ceph * fix type of out_suffix * enhance docstring * refactor unittest for petrel * refactor unittest for disk backend * update io.md * add concat_paths method * fix CI * mock check_exist * improve docstring * improve docstring * improve docstring * improve docstring * add isdir and copyfile for file backend * delete copyfile and add get_local_path * remove isdir method of petrel * fix typo * rename check_exists to exists * refactor code and polish docstring * fix windows ci * add comment and polish docstring * polish docstring * polish docstring * rename _path_mapping to _map_path * polish docstring and fix typo * refactor get_local_path * add list_dir_or_file for FileClient * add list_dir_or_file for PetrelBackend * fix windows ci * Add return docstring * polish docstring * fix typo * fix typo * fix typo * fix error when mocking PetrelBackend * deprecate the conversion from Path to str * add docs for loading checkpoints with FileClient * rename keep_log to keep_local * refactor map_path * add _ensure_methods to ensure methods have been implemented * fix list_dir_or_file * rename _ensure_method_implemented to has_method * refactor * polish information * format information --- mmcv/fileio/file_client.py | 73 ++++++++++++++++------ mmcv/fileio/handlers/base.py | 1 - mmcv/runner/checkpoint.py | 53 +++++++++++----- mmcv/runner/hooks/checkpoint.py | 71 +++++++++++++++++---- mmcv/runner/hooks/evaluation.py | 55 ++++++++++++++-- mmcv/runner/hooks/logger/text.py | 93 +++++++++++++++++++++++++--- tests/test_fileclient.py | 49 +++++++++++---- tests/test_runner/test_checkpoint.py | 42 ++++++++++++- tests/test_runner/test_eval_hook.py | 33 ++++++++++ tests/test_runner/test_hooks.py | 47 +++++++++++++- 10 files changed, 443 insertions(+), 74 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index a6c0f8b89e..b2d622868c 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -11,6 +11,7 @@ from typing import Iterable, Iterator, Optional, Tuple, Union from urllib.request import urlopen +import mmcv from mmcv.utils.misc import has_method from mmcv.utils.path import is_filepath @@ -23,6 +24,17 @@ class BaseStorageBackend(metaclass=ABCMeta): as texts. """ + # a flag to indicate whether the backend can create a symlink for a file + _allow_symlink = False + + @property + def name(self): + return self.__class__.__name__ + + @property + def allow_symlink(self): + return self._allow_symlink + @abstractmethod def get(self, filepath): pass @@ -41,8 +53,8 @@ class CephBackend(BaseStorageBackend): will be replaced by ``dst``. Default: None. .. warning:: - :class:`CephBackend` will be deprecated, please use - :class:`PetrelBackend` instead + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. """ def __init__(self, path_mapping=None): @@ -266,8 +278,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: filepath = self._format_path(filepath) return self._client.contains(filepath) - def concat_paths(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: """Concatenate all file paths. Args: @@ -377,7 +389,7 @@ def _list_dir_or_file(dir_path, list_dir, list_file, suffix, # is a directory, because `self.isdir` relies on # `self._client.list` if path.endswith('/'): # a directory path - next_dir_path = self.concat_paths(dir_path, path) + next_dir_path = self.join_path(dir_path, path) if list_dir: # get the relative path and exclude the last # character '/' @@ -388,7 +400,7 @@ def _list_dir_or_file(dir_path, list_dir, list_file, suffix, list_file, suffix, recursive) else: # a file path - absolute_path = self.concat_paths(dir_path, path) + absolute_path = self.join_path(dir_path, path) rel_path = absolute_path[len(root):] if (suffix is None or rel_path.endswith(suffix)) and list_file: @@ -491,6 +503,8 @@ def get_text(self, filepath, encoding=None): class HardDiskBackend(BaseStorageBackend): """Raw hard disks storage backend.""" + _allow_symlink = True + def get(self, filepath: Union[str, Path]) -> bytes: """Read data from a given ``filepath`` with 'rb' mode. @@ -524,10 +538,15 @@ def get_text(self, def put(self, obj: bytes, filepath: Union[str, Path]) -> None: """Write data to a given ``filepath`` with 'wb' mode. + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + Args: obj (bytes): Data to be written. filepath (str or Path): Path to write data. """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) with open(filepath, 'wb') as f: f.write(obj) @@ -537,12 +556,17 @@ def put_text(self, encoding: str = 'utf-8') -> None: """Write data to a given ``filepath`` with 'w' mode. + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + Args: obj (str): Data to be written. filepath (str or Path): Path to write data. encoding (str): The encoding format used to open the ``filepath``. Default: 'utf-8'. """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) with open(filepath, 'w', encoding=encoding) as f: f.write(obj) @@ -579,7 +603,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: return osp.isdir(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether it is a file. + """Check whether a file path is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -590,8 +614,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def concat_paths(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: """Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -714,7 +738,7 @@ class FileClient: Note that It can also register other backend accessor with a given name, prefixes, and backend class. In addition, We use the singleton pattern to avoid repeated object creation. If the arguments are the same, the same - object is returned. + object will be returned. Args: backend (str, optional): The storage backend type. Options are "disk", @@ -788,18 +812,21 @@ def __new__(cls, backend=None, prefix=None, **kwargs): _instance = super().__new__(cls) if backend is not None: _instance.client = cls._backends[backend](**kwargs) - _instance.backend_name = backend else: _instance.client = cls._prefix_to_backends[prefix](**kwargs) - # infer the backend name according to the prefix - for backend_name, backend_cls in cls._backends.items(): - if isinstance(_instance.client, backend_cls): - _instance.backend_name = backend_name - break + cls._instances[arg_key] = _instance return _instance + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + @staticmethod def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: """Parse the prefix of a uri. @@ -980,6 +1007,10 @@ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: def put(self, obj: bytes, filepath: Union[str, Path]) -> None: """Write data to a given ``filepath`` with 'wb' mode. + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + Args: obj (bytes): Data to be written. filepath (str or Path): Path to write data. @@ -989,6 +1020,10 @@ def put(self, obj: bytes, filepath: Union[str, Path]) -> None: def put_text(self, obj: str, filepath: Union[str, Path]) -> None: """Write data to a given ``filepath`` with 'w' mode. + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + Args: obj (str): Data to be written. filepath (str or Path): Path to write data. @@ -1041,8 +1076,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def concat_paths(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: """Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -1054,7 +1089,7 @@ def concat_paths(self, filepath: Union[str, Path], Returns: str: The result of concatenation. """ - return self.client.concat_paths(filepath, *filepaths) + return self.client.join_path(filepath, *filepaths) @contextmanager def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index 5f28b0acc6..288878bc57 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -7,7 +7,6 @@ class BaseFileHandler(metaclass=ABCMeta): # str-like object or bytes-like object. Pickle only processes bytes-like # objects but json only processes str-like object. If it is str-like # object, `StringIO` will be used to process the buffer. - str_like = True @abstractmethod diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index e266608c3a..4db75d23f7 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -323,7 +323,7 @@ def load_from_pavi(filename, map_location=None): @CheckpointLoader.register_scheme(prefixes='s3://') -def load_from_ceph(filename, map_location=None, backend='ceph'): +def load_from_ceph(filename, map_location=None, backend='petrel'): """load checkpoint through the file path prefixed with s3. In distributed setting, this function download ckpt at all ranks to different temporary directories. @@ -331,20 +331,35 @@ def load_from_ceph(filename, map_location=None, backend='ceph'): Args: filename (str): checkpoint file path with s3 prefix map_location (str, optional): Same as :func:`torch.load`. - backend (str): The storage backend type. Options are "disk", "ceph", - "memcached" and "lmdb". Default: 'ceph' + backend (str, optional): The storage backend type. Options are 'ceph', + 'petrel'. Default: 'petrel'. + + .. warning:: + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. Returns: dict or OrderedDict: The loaded checkpoint. """ - - allowed_backends = ['ceph'] + allowed_backends = ['ceph', 'petrel'] if backend not in allowed_backends: raise ValueError(f'Load from Backend {backend} is not supported.') - fileclient = FileClient(backend=backend) - buffer = io.BytesIO(fileclient.get(filename)) - checkpoint = torch.load(buffer, map_location=map_location) + if backend == 'ceph': + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') + + # CephClient and PetrelBackend have the same prefix 's3://' and the latter + # will be chosen as default. If PetrelBackend can not be instantiated + # successfully, the CephClient will be chosen. + try: + file_client = FileClient(backend=backend) + except ImportError: + allowed_backends.remove(backend) + file_client = FileClient(backend=allowed_backends[0]) + + with io.BytesIO(file_client.get(filename)) as buffer: + checkpoint = torch.load(buffer, map_location=map_location) return checkpoint @@ -506,7 +521,6 @@ def load_checkpoint(model, pair of the regular expression operations. Default: strip the prefix 'module.' by [(r'^module\\.', '')]. - Returns: dict or OrderedDict: The loaded checkpoint. """ @@ -616,7 +630,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): return destination -def save_checkpoint(model, filename, optimizer=None, meta=None): +def save_checkpoint(model, + filename, + optimizer=None, + meta=None, + file_client_args=None): """Save checkpoint to file. The checkpoint will have 3 fields: ``meta``, ``state_dict`` and @@ -627,6 +645,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` """ if meta is None: meta = {} @@ -654,6 +676,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): checkpoint['optimizer'][name] = optim.state_dict() if filename.startswith('pavi://'): + if file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" if filename starts with' + f'"pavi://", but got {file_client_args}') try: from pavi import modelcloud from pavi import exception @@ -674,8 +700,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): f.flush() model.create_file(checkpoint_file, name=model_name) else: - mmcv.mkdir_or_exist(osp.dirname(filename)) - # immediately flush buffer - with open(filename, 'wb') as f: + file_client = FileClient.infer_client(file_client_args, filename) + with io.BytesIO() as f: torch.save(checkpoint, f) - f.flush() + file_client.put(f.getvalue(), filename) diff --git a/mmcv/runner/hooks/checkpoint.py b/mmcv/runner/hooks/checkpoint.py index d99dcb3e62..7bb75f402a 100644 --- a/mmcv/runner/hooks/checkpoint.py +++ b/mmcv/runner/hooks/checkpoint.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os +import os.path as osp +import warnings +from mmcv.fileio import FileClient from ..dist_utils import allreduce_params, master_only from .hook import HOOKS, Hook @@ -18,16 +20,32 @@ class CheckpointHook(Hook): save_optimizer (bool): Whether to save optimizer state_dict in the checkpoint. It is usually used for resuming experiments. Default: True. - out_dir (str, optional): The directory to save checkpoints. If not - specified, ``runner.work_dir`` will be used by default. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, ``runner.work_dir`` will be used by default. If + specified, the ``out_dir`` will be the concatenation of ``out_dir`` + and the last level directory of ``runner.work_dir``. + `Changed in version 1.3.16.` max_keep_ckpts (int, optional): The maximum checkpoints to keep. In some cases we want only the latest few checkpoints and would like to delete old ones to save the disk space. Default: -1, which means unlimited. - save_last (bool): Whether to force the last checkpoint to be saved - regardless of interval. - sync_buffer (bool): Whether to synchronize buffers in different - gpus. Default: False. + save_last (bool, optional): Whether to force the last checkpoint to be + saved regardless of interval. Default: True. + sync_buffer (bool, optional): Whether to synchronize buffers in + different gpus. Default: False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + + .. warning:: + Before v1.3.16, the ``out_dir`` argument indicates the path where the + checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the + root directory and the final path to save checkpoint is the + concatenation of ``out_dir`` and the last level directory of + ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A" + and the value of ``runner.work_dir`` is "/path/of/B", then the final + path will be "/path/of/A/B". """ def __init__(self, @@ -38,6 +56,7 @@ def __init__(self, max_keep_ckpts=-1, save_last=True, sync_buffer=False, + file_client_args=None, **kwargs): self.interval = interval self.by_epoch = by_epoch @@ -47,11 +66,39 @@ def __init__(self, self.save_last = save_last self.args = kwargs self.sync_buffer = sync_buffer + self.file_client_args = file_client_args def before_run(self, runner): if not self.out_dir: self.out_dir = runner.work_dir + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + + runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by ' + f'{self.file_client.name}.')) + + # disable the create_symlink option because some file backends do not + # allow to create a symlink + if 'create_symlink' in self.args: + if self.args[ + 'create_symlink'] and not self.file_client.allow_symlink: + self.args['create_symlink'] = False + warnings.warn( + ('create_symlink is set as True by the user but is changed' + 'to be False because creating symbolic link is not ' + f'allowed in {self.file_client.name}')) + else: + self.args['create_symlink'] = self.file_client.allow_symlink + def after_train_epoch(self, runner): if not self.by_epoch: return @@ -81,7 +128,7 @@ def _save_checkpoint(self, runner): cur_ckpt_filename = self.args.get( 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) runner.meta.setdefault('hook_msgs', dict()) - runner.meta['hook_msgs']['last_ckpt'] = os.path.join( + runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( self.out_dir, cur_ckpt_filename) # remove other checkpoints if self.max_keep_ckpts > 0: @@ -96,10 +143,10 @@ def _save_checkpoint(self, runner): -self.interval) filename_tmpl = self.args.get('filename_tmpl', name) for _step in redundant_ckpts: - ckpt_path = os.path.join(self.out_dir, - filename_tmpl.format(_step)) - if os.path.exists(ckpt_path): - os.remove(ckpt_path) + ckpt_path = self.file_client.join_path( + self.out_dir, filename_tmpl.format(_step)) + if self.file_client.isfile(ckpt_path): + self.file_client.remove(ckpt_path) else: break diff --git a/mmcv/runner/hooks/evaluation.py b/mmcv/runner/hooks/evaluation.py index 7d2141d3b2..e0ccf3f732 100644 --- a/mmcv/runner/hooks/evaluation.py +++ b/mmcv/runner/hooks/evaluation.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import os.path as osp import warnings from math import inf @@ -8,6 +7,7 @@ from torch.nn.modules.batchnorm import _BatchNorm from torch.utils.data import DataLoader +from mmcv.fileio import FileClient from mmcv.utils import is_seq_of from .hook import Hook from .logger import LoggerHook @@ -54,6 +54,14 @@ class EvalHook(Hook): less_keys (List[str] | None, optional): Metric keys that will be inferred by 'less' comparison rule. If ``None``, _default_less_keys will be used. (default: ``None``) + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + `New in version 1.3.16.` + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. + `New in version 1.3.16.` **eval_kwargs: Evaluation arguments fed into the evaluate function of the dataset. @@ -84,6 +92,8 @@ def __init__(self, test_fn=None, greater_keys=None, less_keys=None, + out_dir=None, + file_client_args=None, **eval_kwargs): if not isinstance(dataloader, DataLoader): raise TypeError(f'dataloader must be a pytorch DataLoader, ' @@ -137,6 +147,9 @@ def __init__(self, self.best_ckpt_path = None self._init_rule(rule, self.save_best) + self.out_dir = out_dir + self.file_client_args = file_client_args + def _init_rule(self, rule, key_indicator): """Initialize rule, key_indicator, comparison_func, and best score. @@ -187,6 +200,23 @@ def _init_rule(self, rule, key_indicator): self.compare_func = self.rule_map[self.rule] def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'The best checkpoint will be saved to {self.out_dir} by ' + f'{self.file_client.name}')) + if self.save_best is not None: if runner.meta is None: warnings.warn('runner.meta is None. Creating an empty one.') @@ -299,15 +329,20 @@ def _save_ckpt(self, runner, key_score): best_score = key_score runner.meta['hook_msgs']['best_score'] = best_score - if self.best_ckpt_path and osp.isfile(self.best_ckpt_path): - os.remove(self.best_ckpt_path) + if self.best_ckpt_path and self.file_client.isfile( + self.best_ckpt_path): + self.file_client.remove(self.best_ckpt_path) + runner.logger.info( + (f'The previous best checkpoint {self.best_ckpt_path} was ' + 'removed')) best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' - self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name) + self.best_ckpt_path = self.file_client.join_path( + self.out_dir, best_ckpt_name) runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path runner.save_checkpoint( - runner.work_dir, best_ckpt_name, create_symlink=False) + self.out_dir, best_ckpt_name, create_symlink=False) runner.logger.info( f'Now best checkpoint is saved as {best_ckpt_name}.') runner.logger.info( @@ -378,6 +413,12 @@ class DistEvalHook(EvalHook): broadcast_bn_buffer (bool): Whether to broadcast the buffer(running_mean and running_var) of rank 0 to other rank before evaluation. Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. **eval_kwargs: Evaluation arguments fed into the evaluate function of the dataset. """ @@ -395,6 +436,8 @@ def __init__(self, broadcast_bn_buffer=True, tmpdir=None, gpu_collect=False, + out_dir=None, + file_client_args=None, **eval_kwargs): if test_fn is None: @@ -411,6 +454,8 @@ def __init__(self, test_fn=test_fn, greater_keys=greater_keys, less_keys=less_keys, + out_dir=out_dir, + file_client_args=file_client_args, **eval_kwargs) self.broadcast_bn_buffer = broadcast_bn_buffer diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py index 40a619e5ef..043c7bf20b 100644 --- a/mmcv/runner/hooks/logger/text.py +++ b/mmcv/runner/hooks/logger/text.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import datetime +import os import os.path as osp from collections import OrderedDict @@ -7,6 +8,8 @@ import torch.distributed as dist import mmcv +from mmcv.fileio.file_client import FileClient +from mmcv.utils import is_tuple_of, scandir from ..hook import HOOKS from .base import LoggerHook @@ -19,14 +22,34 @@ class TextLoggerHook(LoggerHook): saved in json file. Args: - by_epoch (bool): Whether EpochBasedRunner is used. - interval (int): Logging interval (every k iterations). - ignore_last (bool): Ignore the log of last iterations in each epoch - if less than `interval`. - reset_flag (bool): Whether to clear the output buffer after logging. - interval_exp_name (int): Logging interval for experiment name. This - feature is to help users conveniently get the experiment + by_epoch (bool, optional): Whether EpochBasedRunner is used. + Default: True. + interval (int, optional): Logging interval (every k iterations). + Default: 10. + ignore_last (bool, optional): Ignore the log of last iterations in each + epoch if less than :attr:`interval`. Default: True. + reset_flag (bool, optional): Whether to clear the output buffer after + logging. Default: False. + interval_exp_name (int, optional): Logging interval for experiment + name. This feature is to help users conveniently get the experiment information from screen or log file. Default: 1000. + out_dir (str, optional): Logs are saved in ``runner.work_dir`` default. + If ``out_dir`` is specified, logs will be copied to a new directory + which is the concatenation of ``out_dir`` and the last level + directory of ``runner.work_dir``. Default: None. + `New in version 1.3.16.` + out_suffix (str or tuple[str], optional): Those filenames ending with + ``out_suffix`` will be copied to ``out_dir``. + Default: ('.log.json', '.log', '.py'). + `New in version 1.3.16.` + keep_local (bool, optional): Whether to keep local log when + :attr:`out_dir` is specified. If False, the local log will be + removed. Default: True. + `New in version 1.3.16.` + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` """ def __init__(self, @@ -34,15 +57,49 @@ def __init__(self, interval=10, ignore_last=True, reset_flag=False, - interval_exp_name=1000): + interval_exp_name=1000, + out_dir=None, + out_suffix=('.log.json', '.log', '.py'), + keep_local=True, + file_client_args=None): super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, by_epoch) self.by_epoch = by_epoch self.time_sec_tot = 0 self.interval_exp_name = interval_exp_name + if out_dir is None and file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" when `out_dir` is not' + 'specified.') + self.out_dir = out_dir + + if not (out_dir is None or isinstance(out_dir, str) + or is_tuple_of(out_dir, str)): + raise TypeError('out_dir should be "None" or string or tuple of ' + 'string, but got {out_dir}') + self.out_suffix = out_suffix + + self.keep_local = keep_local + self.file_client_args = file_client_args + if self.out_dir is not None: + self.file_client = FileClient.infer_client(file_client_args, + self.out_dir) + def before_run(self, runner): super(TextLoggerHook, self).before_run(runner) + + if self.out_dir is not None: + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'Text logs will be saved to {self.out_dir} by ' + f'{self.file_client.name} after the training process.')) + self.start_iter = runner.iter self.json_log_path = osp.join(runner.work_dir, f'{runner.timestamp}.log.json') @@ -177,3 +234,23 @@ def log(self, runner): self._log_info(log_dict, runner) self._dump_log(log_dict, runner) return log_dict + + def after_run(self, runner): + # copy or upload logs to self.out_dir + if self.out_dir is not None: + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + out_filepath = self.file_client.join_path( + self.out_dir, filename) + with open(local_filepath, 'r') as f: + self.file_client.put_text(f.read(), out_filepath) + + runner.logger.info( + (f'The file {local_filepath} has been uploaded to ' + f'{out_filepath}.')) + + if not self.keep_local: + os.remove(local_filepath) + runner.logger.info( + (f'{local_filepath} was removed due to the ' + '`self.keep_local=False`')) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index d15483c94c..30f32432a9 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -132,6 +132,10 @@ def test_error(self): def test_disk_backend(self): disk_backend = FileClient('disk') + # test `name` attribute + assert disk_backend.name == 'HardDiskBackend' + # test `allow_symlink` attribute + assert disk_backend.allow_symlink # test `get` # input path is Path object img_bytes = disk_backend.get(self.img_path) @@ -157,11 +161,19 @@ def test_disk_backend(self): filepath1 = Path(tmp_dir) / 'test.jpg' disk_backend.put(b'disk', filepath1) assert filepath1.open('rb').read() == b'disk' + # test the `mkdir_or_exist` behavior in `put` + _filepath1 = Path(tmp_dir) / 'not_existed_dir1' / 'test.jpg' + disk_backend.put(b'disk', _filepath1) + assert _filepath1.open('rb').read() == b'disk' # test `put_text` filepath2 = Path(tmp_dir) / 'test.txt' disk_backend.put_text('disk', filepath2) assert filepath2.open('r').read() == 'disk' + # test the `mkdir_or_exist` behavior in `put_text` + _filepath2 = Path(tmp_dir) / 'not_existed_dir2' / 'test.txt' + disk_backend.put_text('disk', _filepath2) + assert _filepath2.open('r').read() == 'disk' # test `isfile` assert disk_backend.isfile(filepath2) @@ -179,11 +191,11 @@ def test_disk_backend(self): assert str(filepath1) == path assert osp.isfile(filepath1) - # test `concat_paths` + # test `join_path` disk_dir = '/path/of/your/directory' - assert disk_backend.concat_paths(disk_dir, 'file') == \ + assert disk_backend.join_path(disk_dir, 'file') == \ osp.join(disk_dir, 'file') - assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \ + assert disk_backend.join_path(disk_dir, 'dir', 'file') == \ osp.join(disk_dir, 'dir', 'file') # test `list_dir_or_file` @@ -268,6 +280,9 @@ def test_disk_backend(self): def test_ceph_backend(self): ceph_backend = FileClient('ceph') + # test `allow_symlink` attribute + assert not ceph_backend.allow_symlink + # input path is Path object with pytest.raises(NotImplementedError): ceph_backend.get_text(self.text_path) @@ -305,6 +320,9 @@ def test_ceph_backend(self): def test_petrel_backend(self, backend, prefix): petrel_backend = FileClient(backend=backend, prefix=prefix) + # test `allow_symlink` attribute + assert not petrel_backend.allow_symlink + # input path is Path object img_bytes = petrel_backend.get(self.img_path) img = mmcv.imfrombytes(img_bytes) @@ -415,12 +433,12 @@ def test_petrel_backend(self, backend, prefix): assert petrel_backend.isfile(petrel_path) mock_contains.assert_called_once_with(petrel_path) - # test `concat_paths` - assert petrel_backend.concat_paths(petrel_dir, 'file') == \ + # test `join_path` + assert petrel_backend.join_path(petrel_dir, 'file') == \ f'{petrel_dir}/file' - assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \ + assert petrel_backend.join_path(f'{petrel_dir}/', 'file') == \ f'{petrel_dir}/file' - assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ + assert petrel_backend.join_path(petrel_dir, 'dir', 'file') == \ f'{petrel_dir}/dir/file' # test `get_local_path` @@ -528,6 +546,9 @@ def test_memcached_backend(self): mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None) mc_backend = FileClient('memcached', **mc_cfg) + # test `allow_symlink` attribute + assert not mc_backend.allow_symlink + # input path is Path object with pytest.raises(NotImplementedError): mc_backend.get_text(self.text_path) @@ -550,6 +571,9 @@ def test_lmdb_backend(self): # db_path is Path object lmdb_backend = FileClient('lmdb', db_path=lmdb_path) + # test `allow_symlink` attribute + assert not lmdb_backend.allow_symlink + with pytest.raises(NotImplementedError): lmdb_backend.get_text(self.text_path) @@ -574,6 +598,9 @@ def test_http_backend(self, backend, prefix): text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ 'master/tests/data/filelist.txt' + # test `allow_symlink` attribute + assert not http_backend.allow_symlink + # input is path or Path object with pytest.raises(Exception): http_backend.get(self.img_path) @@ -659,17 +686,17 @@ def test_infer_client(self): # HardDiskBackend file_client_args = {'backend': 'disk'} client = FileClient.infer_client(file_client_args) - assert client.backend_name == 'disk' + assert client.name == 'HardDiskBackend' client = FileClient.infer_client(uri=self.img_path) - assert client.backend_name == 'disk' + assert client.name == 'HardDiskBackend' # PetrelBackend file_client_args = {'backend': 'petrel'} client = FileClient.infer_client(file_client_args) - assert client.backend_name == 'petrel' + assert client.name == 'PetrelBackend' uri = 's3://user_data' client = FileClient.infer_client(uri=uri) - assert client.backend_name == 'petrel' + assert client.name == 'PetrelBackend' def test_register_backend(self): diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index 75aa9ddd75..9856724318 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -1,17 +1,22 @@ import sys from collections import OrderedDict from tempfile import TemporaryDirectory -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch import torch.nn as nn +import torch.optim as optim from torch.nn.parallel import DataParallel +from mmcv.fileio.file_client import PetrelBackend from mmcv.parallel.registry import MODULE_WRAPPERS from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix, get_state_dict, load_checkpoint, - load_from_pavi) + load_from_pavi, save_checkpoint) + +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() @MODULE_WRAPPERS.register_module() @@ -392,3 +397,36 @@ def load_from_abc(filename, map_location): filename = 'a/b/c/d' loader = CheckpointLoader._get_checkpoint_loader(filename) assert loader.__name__ == 'load_from_abc' + + +def test_save_checkpoint(tmp_path): + model = Model() + optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + # meta is not a dict + with pytest.raises(TypeError): + save_checkpoint(model, '/path/of/your/filename', meta='invalid type') + + # 1. save to disk + filename = str(tmp_path / 'checkpoint1.pth') + save_checkpoint(model, filename) + + filename = str(tmp_path / 'checkpoint2.pth') + save_checkpoint(model, filename, optimizer) + + filename = str(tmp_path / 'checkpoint3.pth') + save_checkpoint(model, filename, meta={'test': 'test'}) + + filename = str(tmp_path / 'checkpoint4.pth') + save_checkpoint(model, filename, file_client_args={'backend': 'disk'}) + + # 2. save to petrel oss + with patch.object(PetrelBackend, 'put') as mock_method: + filename = 's3://path/of/your/checkpoint1.pth' + save_checkpoint(model, filename) + mock_method.assert_called() + + with patch.object(PetrelBackend, 'put') as mock_method: + filename = 's3://path//of/your/checkpoint2.pth' + save_checkpoint( + model, filename, file_client_args={'backend': 'petrel'}) + mock_method.assert_called() diff --git a/tests/test_runner/test_eval_hook.py b/tests/test_runner/test_eval_hook.py index a746f49b55..3cbef44ef4 100644 --- a/tests/test_runner/test_eval_hook.py +++ b/tests/test_runner/test_eval_hook.py @@ -1,5 +1,6 @@ import json import os.path as osp +import sys import tempfile import unittest.mock as mock from collections import OrderedDict @@ -11,12 +12,16 @@ import torch.optim as optim from torch.utils.data import DataLoader, Dataset +from mmcv.fileio.file_client import PetrelBackend from mmcv.runner import DistEvalHook as BaseDistEvalHook from mmcv.runner import EpochBasedRunner from mmcv.runner import EvalHook as BaseEvalHook from mmcv.runner import IterBasedRunner from mmcv.utils import get_logger, scandir +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() + class ExampleDataset(Dataset): @@ -298,6 +303,34 @@ def test_eval_hook(): assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == -3 + # test EvalHook with specified `out_dir` + loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) + out_dir = 's3://user/data' + eval_hook = EvalHook( + data_loader, interval=1, save_best='auto', out_dir=out_dir) + + with patch.object(PetrelBackend, 'put') as mock_put, \ + patch.object(PetrelBackend, 'remove') as mock_remove, \ + patch.object(PetrelBackend, 'isfile') as mock_isfile, \ + tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + ckpt_path = f'{out_dir}/{basename}/best_acc_epoch_4.pth' + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert runner.meta['hook_msgs']['best_score'] == 7 + + assert mock_put.call_count == 3 + assert mock_remove.call_count == 2 + assert mock_isfile.call_count == 2 + @patch('mmcv.engine.single_gpu_test', MagicMock) @patch('mmcv.engine.multi_gpu_test', MagicMock) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index bb0e758504..61c347e666 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -12,7 +12,7 @@ import shutil import sys import tempfile -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import pytest import torch @@ -20,6 +20,7 @@ from torch.nn.init import constant_ from torch.utils.data import DataLoader +from mmcv.fileio.file_client import PetrelBackend from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook, Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, @@ -34,8 +35,11 @@ OneCycleLrUpdaterHook, StepLrUpdaterHook) +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() -def test_checkpoint_hook(): + +def test_checkpoint_hook(tmp_path): """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook.""" # test epoch based runner @@ -49,6 +53,25 @@ def test_checkpoint_hook(): runner.work_dir, 'epoch_1.pth') shutil.rmtree(runner.work_dir) + # test petrel oss when the type of runner is `EpochBasedRunner` + runner = _build_demo_runner('EpochBasedRunner', max_epochs=4) + runner.meta = dict() + out_dir = 's3://user/data' + with patch.object(PetrelBackend, 'put') as mock_put, \ + patch.object(PetrelBackend, 'remove') as mock_remove, \ + patch.object(PetrelBackend, 'isfile') as mock_isfile: + checkpointhook = CheckpointHook( + interval=1, out_dir=out_dir, by_epoch=True, max_keep_ckpts=2) + runner.register_hook(checkpointhook) + runner.run([loader], [('train', 1)]) + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + assert runner.meta['hook_msgs']['last_ckpt'] == \ + '/'.join([out_dir, basename, 'epoch_4.pth']) + mock_put.assert_called() + mock_remove.assert_called() + mock_isfile.assert_called() + shutil.rmtree(runner.work_dir) + # test iter based runner runner = _build_demo_runner( 'IterBasedRunner', max_iters=1, max_epochs=None) @@ -60,6 +83,26 @@ def test_checkpoint_hook(): runner.work_dir, 'iter_1.pth') shutil.rmtree(runner.work_dir) + # test petrel oss when the type of runner is `IterBasedRunner` + runner = _build_demo_runner( + 'IterBasedRunner', max_iters=4, max_epochs=None) + runner.meta = dict() + out_dir = 's3://user/data' + with patch.object(PetrelBackend, 'put') as mock_put, \ + patch.object(PetrelBackend, 'remove') as mock_remove, \ + patch.object(PetrelBackend, 'isfile') as mock_isfile: + checkpointhook = CheckpointHook( + interval=1, out_dir=out_dir, by_epoch=False, max_keep_ckpts=2) + runner.register_hook(checkpointhook) + runner.run([loader], [('train', 1)]) + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + assert runner.meta['hook_msgs']['last_ckpt'] == \ + '/'.join([out_dir, basename, 'iter_4.pth']) + mock_put.assert_called() + mock_remove.assert_called() + mock_isfile.assert_called() + shutil.rmtree(runner.work_dir) + def test_ema_hook(): """xdoctest -m tests/test_hooks.py test_ema_hook.""" From 41f0f43f5d8321ba39fa31c5b9bc0259dc4b5e32 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 24 Oct 2021 14:27:09 +0800 Subject: [PATCH 09/30] bump version to v1.3.16 (#1430) --- mmcv/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/version.py b/mmcv/version.py index aad3d6fa81..d954f6a3be 100644 --- a/mmcv/version.py +++ b/mmcv/version.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -__version__ = '1.3.15' +__version__ = '1.3.16' def parse_version_info(version_str: str, length: int = 4) -> tuple: From d8d7e3a89e8b26ae907019f67848c3a53aa49a50 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 2 Nov 2021 20:50:50 +0800 Subject: [PATCH 10/30] [Fix]: Update test data of test_tin_shift (#1426) * Update test data of test_tin_shift * Delete tmp.engine * add pytest raises asserterror test * raise valueerror, update test log * add more comment * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmcv/ops/tin_shift.py | 5 + tests/test_ops/test_tin_shift.py | 167 ++++++++++++++++++++++++------- 2 files changed, 137 insertions(+), 35 deletions(-) diff --git a/mmcv/ops/tin_shift.py b/mmcv/ops/tin_shift.py index 74043637cc..472c9fcfe4 100644 --- a/mmcv/ops/tin_shift.py +++ b/mmcv/ops/tin_shift.py @@ -18,6 +18,11 @@ class TINShiftFunction(Function): @staticmethod def forward(ctx, input, shift): + C = input.size(2) + num_segments = shift.size(1) + if C // num_segments <= 0 or C % num_segments != 0: + raise ValueError('C should be a multiple of num_segments, ' + f'but got C={C} and num_segments={num_segments}.') ctx.save_for_backward(shift) diff --git a/tests/test_ops/test_tin_shift.py b/tests/test_ops/test_tin_shift.py index 898c46e4c4..93cea6ea58 100644 --- a/tests/test_ops/test_tin_shift.py +++ b/tests/test_ops/test_tin_shift.py @@ -14,49 +14,127 @@ cur_dir = os.path.dirname(os.path.abspath(__file__)) -inputs = ([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]], - [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]], - [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]], - [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]], - [[[-0.5897, 0.7544], [1.0593, 0.8388], [-0.5732, 0.5692]], - [[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]], - [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]], - [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]]]]) +inputs = ([[[[0.88572276, 0.46422583], [0.97408265, 0.59547687], + [0.030812204, 0.96236038], [0.75418317, 0.44058233], + [0.33279222, 0.00084149837], [0.7069388, 0.23255438], + [0.13547045, 0.81549376], [0.40174931, 0.36317211]], + [[0.57444429, 0.15905505], [0.39897251, 0.25790238], + [0.93282568, 0.18451685], [0.92526674, 0.18283755], + [0.31664443, 0.59323865], [0.1957739, 0.42505842], + [0.081158757, 0.81340349], [0.43456328, 0.30195212]], + [[0.8198145, 0.05990988], [0.98062474, 0.34803438], + [0.10412294, 0.37183142], [0.15021622, 0.038857818], + [0.40985721, 0.42253625], [0.71150124, 0.59778064], + [0.83851069, 0.15194464], [0.097513378, 0.74820143]], + [[0.80680406, 0.49327564], [0.17821097, 0.12980539], + [0.50657678, 0.14446253], [0.04178369, 0.53071898], + [0.84983683, 0.3826949], [0.32193625, 0.91275406], + [0.75628334, 0.52934098], [0.27994192, 0.3053292]]], + [[[0.082397044, 0.4210068], [0.23563534, 0.7938987], + [0.63669145, 0.69397897], [0.8844561, 0.97854084], + [0.79027033, 0.60640401], [0.63528901, 0.72172403], + [0.0097346902, 0.70800996], [0.87891227, 0.13674974]], + [[0.74329448, 0.0243572], [0.82178867, 0.85750699], + [0.7568835, 0.73146772], [0.5031184, 0.30479157], + [0.28713053, 0.47414285], [0.4682079, 0.067471564], + [0.48368263, 0.14590704], [0.25397325, 0.19946373]], + [[0.4291026, 0.068739474], [0.7159555, 0.79903615], + [0.76412082, 0.85348046], [0.081224024, 0.82264912], + [0.97173303, 0.24291694], [0.48957139, 0.43488795], + [0.67382395, 0.21889746], [0.36712623, 0.67127824]], + [[0.12054044, 0.18096751], [0.86675781, 0.54755616], + [0.68208277, 0.15164375], [0.79991871, 0.80811197], + [0.85256428, 0.68253738], [0.185983, 0.95642138], + [0.48102546, 0.28009653], [0.35726011, 0.58168036]]]]) shifts = [([[1, 0, 1, -2], [-2, 1, -1, 1]]), ([[2, 1, 2, -1], [-1, 2, 0, 2]])] -outputs = [([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]], - [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]], - [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]], - [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]], - [[[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]], - [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]], - [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]], - [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]]), - ([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]], - [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]], - [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]], - [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]], - [[[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]], - [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]], - [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]], - [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]])] - -grads = [[[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]], - [[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]], - [[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]], - [[0., 0.], [0., 0.], [0., 0.]], [[0., 0.], [0., 0.], [0., 0.]]]], - [[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]], - [[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]], - [[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]], - [[1., 1.], [1., 1.], [1., 1.]], [[0., 0.], [0., 0.], [0., 0.]]]]] +outputs = [([[[[0.0, 0.0], [0.0, 0.0], [0.030812, 0.96236], [0.75418, 0.44058], + [0.0, 0.0], [0.0, 0.0], [0.83851, 0.15194], [0.097513, 0.7482]], + [[0.88572, 0.46423], [0.97408, 0.59548], [0.93283, 0.18452], + [0.92527, 0.18284], [0.33279, 0.0008415], [0.70694, 0.23255], + [0.75628, 0.52934], [0.27994, 0.30533]], + [[0.57444, 0.15906], [0.39897, 0.2579], [0.10412, 0.37183], + [0.15022, 0.038858], [0.31664, 0.59324], [0.19577, 0.42506], + [0.0, 0.0], [0.0, 0.0]], + [[0.81981, 0.05991], [0.98062, 0.34803], [0.50658, 0.14446], + [0.041784, 0.53072], [0.40986, 0.42254], [0.7115, 0.59778], + [0.0, 0.0], [0.0, 0.0]]], + [[[0.4291, 0.068739], [0.71596, 0.79904], [0.0, 0.0], [0.0, 0.0], + [0.28713, 0.47414], [0.46821, 0.067472], [0.0, 0.0], [0.0, + 0.0]], + [[0.12054, 0.18097], [0.86676, 0.54756], [0.63669, 0.69398], + [0.88446, 0.97854], [0.97173, 0.24292], [0.48957, 0.43489], + [0.0097347, 0.70801], [0.87891, 0.13675]], + [[0.0, 0.0], [0.0, 0.0], [0.75688, 0.73147], [0.50312, 0.30479], + [0.85256, 0.68254], [0.18598, 0.95642], [0.48368, 0.14591], + [0.25397, 0.19946]], + [[0.0, 0.0], [0.0, 0.0], [0.76412, 0.85348], [0.081224, 0.82265], + [0.0, 0.0], [0.0, 0.0], [0.67382, 0.2189], [0.36713, + 0.67128]]]]), + ([[[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], + [0.0, 0.0], [0.081159, 0.8134], [0.43456, 0.30195]], + [[0.0, 0.0], [0.0, 0.0], [0.030812, 0.96236], [0.75418, 0.44058], + [0.0, 0.0], [0.0, 0.0], [0.83851, 0.15194], [0.097513, 0.7482]], + [[0.88572, 0.46423], [0.97408, 0.59548], [0.93283, 0.18452], + [0.92527, 0.18284], [0.33279, 0.0008415], [0.70694, 0.23255], + [0.75628, 0.52934], [0.27994, 0.30533]], + [[0.57444, 0.15906], [0.39897, 0.2579], [0.10412, 0.37183], + [0.15022, 0.038858], [0.31664, 0.59324], [0.19577, 0.42506], + [0.0, 0.0], [0.0, 0.0]]], + [[[0.74329, 0.024357], [0.82179, 0.85751], [0.0, 0.0], [0.0, 0.0], + [0.79027, 0.6064], [0.63529, 0.72172], [0.0, 0.0], [0.0, 0.0]], + [[0.4291, 0.068739], [0.71596, 0.79904], [0.0, 0.0], [0.0, 0.0], + [0.28713, 0.47414], [0.46821, 0.067472], [0.0, 0.0], [0.0, + 0.0]], + [[0.12054, 0.18097], [0.86676, 0.54756], [0.63669, 0.69398], + [0.88446, 0.97854], [0.97173, 0.24292], [0.48957, 0.43489], + [0.0097347, 0.70801], [0.87891, 0.13675]], + [[0.0, 0.0], [0.0, 0.0], [0.75688, 0.73147], [0.50312, 0.30479], + [0.85256, 0.68254], [0.18598, 0.95642], [0.48368, 0.14591], + [0.25397, 0.19946]]]])] + +grads = [ + [[[[0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], + [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]]], + [[[1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]], + [[0., 0.], [0., 0.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]], + [[0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], + [1., 1.]]]], + [[[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.], [1., 1.], + [1., 1.]], + [[0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], + [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]]], + [[[1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]], + [[1., 1.], [1., 1.], [0., 0.], [0., 0.], [1., 1.], [1., 1.], [0., 0.], + [0., 0.]], + [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]], + [[0., 0.], [0., 0.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], + [1., 1.]]]] +] def _test_tinshift_gradcheck(dtype): try: from mmcv.ops import tin_shift except ModuleNotFoundError: - pytest.skip('TinShift op is not successfully compiled') + pytest.skip('TINShift op is not successfully compiled') if dtype == torch.half: pytest.skip('"add_cpu/sub_cpu" not implemented for Half') @@ -78,7 +156,7 @@ def _test_tinshift_allclose(dtype): try: from mmcv.ops import tin_shift except ModuleNotFoundError: - pytest.skip('TinShift op is not successfully compiled') + pytest.skip('TINShift op is not successfully compiled') for shift, output, grad in zip(shifts, outputs, grads): np_input = np.array(inputs) @@ -98,9 +176,28 @@ def _test_tinshift_allclose(dtype): x.grad.data.type(torch.float).cpu().numpy(), np_grad, 1e-3) +def _test_tinshift_assert(dtype): + try: + from mmcv.ops import tin_shift + except ModuleNotFoundError: + pytest.skip('TINShift op is not successfully compiled') + + inputs = [torch.rand(2, 3, 4, 2), torch.rand(2, 3, 4, 2)] + shifts = [torch.rand(2, 3), torch.rand(2, 5)] + + for x, shift in zip(inputs, shifts): + x = x.cuda() + shift = shift.cuda() + + # A ValueError should be raised if ops get inputs with wrong shapes. + with pytest.raises(ValueError): + tin_shift(x, shift) + + @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') @pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half]) def test_tinshift(dtype): _test_tinshift_allclose(dtype=dtype) _test_tinshift_gradcheck(dtype=dtype) + _test_tinshift_assert(dtype=dtype) From 2ea5e5ed879cc822b208eb42c48e5d91e47480ef Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Tue, 2 Nov 2021 20:57:50 +0800 Subject: [PATCH 11/30] fix the wrong function reference bug in BaseTransformerLayer when batch_first is True (#1418) --- mmcv/cnn/bricks/transformer.py | 35 ++++++++++++------------------ tests/test_cnn/test_transformer.py | 25 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py index 6e82e84fed..ed32688af4 100644 --- a/mmcv/cnn/bricks/transformer.py +++ b/mmcv/cnn/bricks/transformer.py @@ -102,27 +102,6 @@ def __init__(self, self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, **kwargs) - if self.batch_first: - - def _bnc_to_nbc(forward): - """Because the dataflow('key', 'query', 'value') of - ``torch.nn.MultiheadAttention`` is (num_query, batch, - embed_dims), We should adjust the shape of dataflow from - batch_first (batch, num_query, embed_dims) to num_query_first - (num_query ,batch, embed_dims), and recover ``attn_output`` - from num_query_first to batch_first.""" - - def forward_wrapper(**kwargs): - convert_keys = ('key', 'query', 'value') - for key in kwargs.keys(): - if key in convert_keys: - kwargs[key] = kwargs[key].transpose(0, 1) - attn_output, attn_output_weights = forward(**kwargs) - return attn_output.transpose(0, 1), attn_output_weights - - return forward_wrapper - - self.attn.forward = _bnc_to_nbc(self.attn.forward) self.proj_drop = nn.Dropout(proj_drop) self.dropout_layer = build_dropout( @@ -199,6 +178,17 @@ def forward(self, if key_pos is not None: key = key + key_pos + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + out = self.attn( query=query, key=key, @@ -206,6 +196,9 @@ def forward(self, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + if self.batch_first: + out = out.transpose(0, 1) + return identity + self.dropout_layer(self.proj_drop(out)) diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py index a4a5f62e9c..106753b423 100644 --- a/tests/test_cnn/test_transformer.py +++ b/tests/test_cnn/test_transformer.py @@ -1,3 +1,5 @@ +import copy + import pytest import torch @@ -5,6 +7,7 @@ from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer, MultiheadAttention, TransformerLayerSequence) +from mmcv.runner import ModuleList def test_multiheadattention(): @@ -92,6 +95,28 @@ def test_ffn(): ffn(input_tensor).sum() + residual.sum() - input_tensor.sum()) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available') +def test_basetransformerlayer_cuda(): + # To test if the BaseTransformerLayer's behaviour remains + # consistent after being deepcopied + operation_order = ('self_attn', 'ffn') + baselayer = BaseTransformerLayer( + operation_order=operation_order, + batch_first=True, + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + ), + ) + baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)]) + baselayers.to('cuda') + x = torch.rand(2, 10, 256).cuda() + for m in baselayers: + x = m(x) + assert x.shape == torch.Size([2, 10, 256]) + + def test_basetransformerlayer(): attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8), feedforward_channels = 2048 From 115de6734a2dc5e203a45255588831a8dd410b66 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Tue, 2 Nov 2021 20:59:26 +0800 Subject: [PATCH 12/30] [Docs] Add mmcv itself in the docs list (#1441) * Add mmcv itself in the docs list * modify link of docs --- docs/conf.py | 27 ++++++++++++++++----------- docs_zh_CN/conf.py | 27 +++++++++++++++++---------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 371ae3de7c..5980a1f65b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -111,47 +111,52 @@ }, { 'name': - 'Projects', + 'Docs', 'children': [ + { + 'name': 'MMCV', + 'url': 'https://mmcv.readthedocs.io/en/latest/', + }, { 'name': 'MMAction2', - 'url': 'https://github.com/open-mmlab/mmaction2', + 'url': 'https://mmaction2.readthedocs.io/en/latest/', }, { 'name': 'MMClassification', - 'url': 'https://github.com/open-mmlab/mmclassification', + 'url': + 'https://mmclassification.readthedocs.io/en/latest/', }, { 'name': 'MMDetection', - 'url': 'https://github.com/open-mmlab/mmdetection', + 'url': 'https://mmdetection.readthedocs.io/en/latest/', }, { 'name': 'MMDetection3D', - 'url': 'https://github.com/open-mmlab/mmdetection3d', + 'url': 'https://mmdetection3d.readthedocs.io/en/latest/', }, { 'name': 'MMEditing', - 'url': 'https://github.com/open-mmlab/mmediting', + 'url': 'https://mmediting.readthedocs.io/en/latest/', }, { 'name': 'MMGeneration', - 'url': 'https://github.com/open-mmlab/mmgeneration', + 'url': 'https://mmgeneration.readthedocs.io/en/latest/', }, { 'name': 'MMOCR', - 'url': 'https://github.com/open-mmlab/mmocr', + 'url': 'https://mmocr.readthedocs.io/en/latest/', }, { 'name': 'MMPose', - 'url': 'https://github.com/open-mmlab/mmpose', + 'url': 'https://mmpose.readthedocs.io/en/latest/', }, { 'name': 'MMSegmentation', - 'url': 'https://github.com/open-mmlab/mmsegmentation', + 'url': 'https://mmsegmentation.readthedocs.io/en/latest/', }, { 'name': 'MMTracking', - 'url': 'https://github.com/open-mmlab/mmtracking', + 'url': 'https://mmtracking.readthedocs.io/en/latest/', }, ] }, diff --git a/docs_zh_CN/conf.py b/docs_zh_CN/conf.py index b004726668..096a3c24ac 100644 --- a/docs_zh_CN/conf.py +++ b/docs_zh_CN/conf.py @@ -111,11 +111,15 @@ }, { 'name': - '算法库', + '文档', 'children': [ + { + 'name': 'MMCV', + 'url': 'https://mmcv.readthedocs.io/zh_CN/latest/', + }, { 'name': 'MMAction2', - 'url': 'https://github.com/open-mmlab/mmaction2', + 'url': 'https://mmaction2.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMClassification', @@ -123,35 +127,38 @@ }, { 'name': 'MMDetection', - 'url': 'https://github.com/open-mmlab/mmdetection', + 'url': + 'https://mmclassification.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMDetection3D', - 'url': 'https://github.com/open-mmlab/mmdetection3d', + 'url': + 'https://mmdetection3d.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMEditing', - 'url': 'https://github.com/open-mmlab/mmediting', + 'url': 'https://mmediting.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMGeneration', - 'url': 'https://github.com/open-mmlab/mmgeneration', + 'url': 'https://mmgeneration.readthedocs.io/en/latest/', }, { 'name': 'MMOCR', - 'url': 'https://github.com/open-mmlab/mmocr', + 'url': 'https://mmocr.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMPose', - 'url': 'https://github.com/open-mmlab/mmpose', + 'url': 'https://mmpose.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMSegmentation', - 'url': 'https://github.com/open-mmlab/mmsegmentation', + 'url': + 'https://mmsegmentation.readthedocs.io/zh_CN/latest/', }, { 'name': 'MMTracking', - 'url': 'https://github.com/open-mmlab/mmtracking', + 'url': 'https://mmtracking.readthedocs.io/zh_CN/latest/', }, ] }, From e59d3e02133e4a9d35eb89f837bf211f45f0b4b0 Mon Sep 17 00:00:00 2001 From: WangJiaZhen <47851024+teamwong111@users.noreply.github.com> Date: Tue, 2 Nov 2021 20:52:23 +0800 Subject: [PATCH 13/30] [Improve] improve checkpoint loading log (#1446) --- mmcv/runner/base_runner.py | 2 -- mmcv/runner/checkpoint.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index d7ed4cb713..25cd98f51c 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -334,8 +334,6 @@ def load_checkpoint(self, map_location='cpu', strict=False, revise_keys=[(r'^module.', '')]): - - self.logger.info('load checkpoint from %s', filename) return load_checkpoint( self.model, filename, diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 4db75d23f7..6ad605b854 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -240,7 +240,8 @@ def load_checkpoint(cls, filename, map_location=None, logger=None): checkpoint_loader = cls._get_checkpoint_loader(filename) class_name = checkpoint_loader.__name__ - mmcv.print_log(f'Use {class_name} loader', logger) + mmcv.print_log( + f'load checkpoint from {class_name[10:]} path: {filename}', logger) return checkpoint_loader(filename, map_location) From 682dba5b0a9571635d156fb488c40a6a939960bc Mon Sep 17 00:00:00 2001 From: Yuxin Liu Date: Mon, 8 Nov 2021 15:24:33 +0800 Subject: [PATCH 14/30] [Feature] Support SigmoidFocalLoss with Cambricon MLU backend (#1346) * [Feature] Support SigmoidFocalLoss with Cambricon MLU backend * refactor MMCV_WITH_MLU macro define * refactor NFU_ALIGN_SIZE, PAD_DOWN and split_pipeline_num * delete extra fool proofing in cpp * [Feature] Support SigmoidFocalLossBackward with Cambricon MLU backend * fix macro definition in SigmoidFocalLoss * refactor mlu files into clang-format * refactor sigmoid focal loss test * refactor Sigmoid Focal Loss file structure. * fix python lint error * fix import torch_mlu error type * fix lint * refactor clang format style to google Co-authored-by: zhouzaida --- .../ops/csrc/common/mlu/common_mlu_helper.hpp | 36 + .../mlu/focal_loss_sigmoid_mlu_kernel.mlu | 776 ++++++++++++++++++ mmcv/ops/csrc/common/pytorch_cpp_helper.hpp | 4 +- mmcv/ops/csrc/common/pytorch_mlu_helper.hpp | 26 + mmcv/ops/csrc/pytorch/focal_loss.cpp | 44 + .../pytorch/mlu/focal_loss_sigmoid_mlu.cpp | 315 +++++++ mmcv/ops/focal_loss.py | 3 +- mmcv/utils/__init__.py | 1 + mmcv/utils/pytorch_wrapper.py | 17 + setup.py | 29 +- tests/test_ops/test_focal_loss.py | 41 +- 11 files changed, 1275 insertions(+), 17 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp create mode 100644 mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/common/pytorch_mlu_helper.hpp create mode 100644 mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp create mode 100644 mmcv/utils/pytorch_wrapper.py diff --git a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp new file mode 100644 index 0000000000..826c42349f --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef UTILS_H_ +#define UTILS_H_ + +#define NFU_ALIGN_SIZE 128 // Byte +#define REM_FOR_STACK (128 * 1024) // 128KB reserved for cncc + +#ifdef __BANG_ARCH__ +#define MAX_NRAM_SIZE \ + (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK) // 128KB reserved for cncc +#define MAX_SRAM_SIZE \ + (__MLU_SRAM_SIZE__ * 1024 - REM_FOR_STACK) // 128KB reserved for cncc +#else +#define MAX_NRAM_SIZE (384 * 1024) // 384KB, initialization value +#define MAX_SRAM_SIZE (1920 * 1024) // 1920KB, initialization value +#endif + +#ifndef PAD_UP +#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y)) +#endif + +#ifndef PAD_DOWN +#define PAD_DOWN(x, y) (((x) / (y)) * (y)) +#endif + +#endif // UTILS_H_ diff --git a/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu new file mode 100644 index 0000000000..028f6c0c9d --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu @@ -0,0 +1,776 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include + +#include "common_mlu_helper.hpp" + +#define PING 0 +#define PONG 1 + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +namespace forward { +template +__mlu_func__ void loadInput(char *nram_input, T *dram_input, const int32_t size, + const int32_t dst_stride = 0, + const int32_t src_stride = 0, + const int32_t count = 1) { + __memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride, + src_stride, count - 1); +} + +template +__mlu_func__ void storeOutput(T *dram_output, char *nram_output, + const int32_t size, const int32_t dst_stride = 0, + const int32_t src_stride = 0, + const int32_t count = 1) { + __memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride, + src_stride, count - 1); +} + +template +__mlu_func__ void compute(T *input, const int32_t *target, const T *weight, + const int32_t has_weight, const int32_t deal_num, + const int32_t n_seg, const int32_t C, float alpha, + float gamma, T *scalar_temp, T *tensor_max, + T *tensor_temp, T *output) { + const int32_t scalar_elem_num = NFU_ALIGN_SIZE / sizeof(T); + + // 0. n_max = max(0, x) + __nramset((T *)tensor_max, deal_num, (T)0); + __bang_cycle_maxequal((T *)tensor_max, (T *)tensor_max, (T *)input, deal_num, + deal_num); + + // 1. ln(1+e^x) = ln(e^(-max) + e^(x-max)) + max + __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); + __bang_cycle_mul((T *)tensor_temp, (T *)tensor_max, (T *)scalar_temp, + deal_num, scalar_elem_num); + __bang_cycle_add((T *)output, (T *)input, (T *)tensor_temp, deal_num, + deal_num); + __bang_active_exphp((T *)output, (T *)output, deal_num); + __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); + __bang_cycle_add((T *)output, (T *)output, (T *)tensor_temp, deal_num, + deal_num); + __bang_active_loghp((T *)output, (T *)output, deal_num); + __bang_cycle_add((T *)output, (T *)output, (T *)tensor_max, deal_num, + deal_num); + + // 2. temp = [1 + e^(-x)] ^ (-r) + __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); + __bang_cycle_mul((T *)tensor_temp, (T *)input, (T *)scalar_temp, deal_num, + scalar_elem_num); + __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); + + __nramset((T *)scalar_temp, scalar_elem_num, (T)1); + __bang_cycle_add((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, + deal_num, scalar_elem_num); + __bang_active_loghp((T *)tensor_temp, (T *)tensor_temp, deal_num); + + __nramset((T *)scalar_temp, scalar_elem_num, (T)(-gamma)); + __bang_cycle_mul((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, + deal_num, scalar_elem_num); + __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); + + // 3.1 output: target != j + __nramset((T *)scalar_temp, scalar_elem_num, (T)(1 - alpha)); + __bang_cycle_mul((T *)output, (T *)output, (T *)scalar_temp, deal_num, + scalar_elem_num); + __bang_cycle_mul((T *)output, (T *)output, (T *)tensor_temp, deal_num, + deal_num); + + // 3.2 output: target == j + const int32_t c_align_size = PAD_UP((sizeof(T) * C), NFU_ALIGN_SIZE); + for (int32_t i = 0; i < n_seg; ++i) { + const int32_t target_value = *((int32_t *)target + i); + if (target_value >= 0 && target_value < C) { + const int32_t offset = i * c_align_size + target_value * sizeof(T); + char *addr_input = (char *)input + offset; + char *addr_output = (char *)output + offset; + const float x = *(T *)addr_input; + const float p = 1. / (1. + exp(-x)); + *(T *)addr_output = -alpha * pow(1. - p, gamma) * log(fmax(p, FLT_MIN)); + } + } + + // with weight + if (has_weight > 0) { + int32_t row_num_elem = deal_num / n_seg; + for (int32_t i = 0; i < n_seg; ++i) { + const int32_t t = *((int32_t *)target + i); + __nramset((T *)scalar_temp, scalar_elem_num, *((T *)weight + t)); + __bang_cycle_mul((T *)output + i * row_num_elem, + (T *)output + i * row_num_elem, (T *)scalar_temp, + row_num_elem, scalar_elem_num); + } + } +} + +template +__mlu_func__ void focalLossSigmoidForwardBlock( + const T *input, const int32_t *target, const T *weight, + const int32_t row_num, const int32_t C, const float alpha, + const float gamma, T *output) { + /* + * NRAM partition + * |-----------------------------------------------------------------------| + * | scalar | + * |-----------------------------------------------------------------------| + * | weight | + * |------------------------------- COMPUTE -------------------------------| + * | | | + * | computeA | computeB | + * | | | + * |------------- PING ------------------------------- PONG ---------------| + * | | | + * | input | input | + * | | | + * |-----------------------------------|-----------------------------------| + * | | | + * | output | output | + * | | | + * |-----------------------------------|-----------------------------------| + * | target | target | + * |-----------------------------------|-----------------------------------| + * + * split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output), + * PONG(input,output). + * split_target_num is 2: PING(target), PONG(target). + */ + const int32_t c_align = PAD_UP(C, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t c_align_size = c_align * sizeof(T); + const int32_t scalar_size = NFU_ALIGN_SIZE; + const int32_t weight_size = (weight != NULL) * c_align_size; + const int32_t split_pipeline_num = 6; + const int32_t split_target_num = 2; + + const int32_t remain_size = MAX_NRAM_SIZE - scalar_size - weight_size; + const int32_t n_seg = remain_size / (split_pipeline_num * c_align_size + + split_target_num * sizeof(int32_t)); + const int32_t deal_num = n_seg * c_align_size / sizeof(T); + const int32_t target_size = n_seg * sizeof(int32_t); + + // nram scalar,weight + char *nram_scalar = (char *)nram_buffer; + char *nram_weight = (char *)nram_scalar + scalar_size; + if (weight_size > 0) { + loadInput(nram_weight, (T *)weight, C * sizeof(T)); + __asm__ volatile("sync;"); + } + + // nram COMPUTE + const int32_t compute_size = 2 * c_align_size * n_seg; + char *nram_compute_a = (char *)nram_weight + weight_size; + char *nram_compute_b = (char *)nram_compute_a + c_align_size * n_seg; + + // nram PING/PONG + const int32_t pingpong_offset = (remain_size - compute_size) / 2; + char *nram_input = (char *)nram_compute_a + 2 * c_align_size * n_seg; + char *nram_output = (char *)nram_compute_a + 3 * c_align_size * n_seg; + char *nram_target = (char *)nram_compute_a + 4 * c_align_size * n_seg; + + const int32_t repeat = row_num / n_seg; + const int32_t remain = row_num % n_seg; + + /* + * Pipeline: The pipeline is processed in three stages: Load, Compute, Store. + * The allocated memory space of NRAM is divided into two parts: + * PING and Pong. In a single time slice, PING is used to process + * IO stream and PONG is used for computation. Both of them are + * processed synchronously until finished. + * + * diagram of PINGPONG: + * |------|-----------------------------------------------------------------| + * | | space | + * |------|-----------------------------------------------------------------| + * | time | Ping | Pong | Ping | Pong | Ping | Pong | + * |------|-----------------------------------------------------------------| + * | 0 | L0 | | | | | | + * | 1 | C0 | L1 | | | | | + * | 2 | S0 | C1 | L2 | | | | + * | 3 | | S1 | C2 | L3 | | | + * | 4 | | | S2 | C3 | L4 | | + * | 5 | | | | S3 | C4 | L5 | + * | 6 | | | | | S4 | C5 | + * | 7 | | | | | | S5 | + * |------|-----------------------------------------------------------------| + */ + + // diagram of PINGPONG: L0 + if (repeat > 0) { + loadInput(nram_input, (T *)input, C * sizeof(T), c_align * sizeof(T), + C * sizeof(T), n_seg); + loadInput(nram_target, (int32_t *)target, target_size); + __asm__ volatile("sync;"); + } + + // diagram of PINGPONG: C0 and L1 + if (repeat > 1) { + loadInput(nram_input + pingpong_offset, (T *)input + C * n_seg, + C * sizeof(T), c_align * sizeof(T), C * sizeof(T), n_seg); + loadInput(nram_target + pingpong_offset, (int32_t *)target + n_seg, + target_size); + compute((T *)nram_input, (int32_t *)nram_target, (T *)nram_weight, + weight_size, deal_num, n_seg, C, alpha, gamma, (T *)nram_scalar, + (T *)nram_compute_a, (T *)nram_compute_b, (T *)nram_output); + __asm__ volatile("sync;"); + } + + for (int32_t i = 0; i < repeat - 2; ++i) { + storeOutput((T *)output + i * C * n_seg, + nram_output + (i % 2) * pingpong_offset, C * sizeof(T), + C * sizeof(T), c_align * sizeof(T), n_seg); + loadInput(nram_input + (i % 2) * pingpong_offset, + (T *)input + (i + 2) * C * n_seg, C * sizeof(T), + c_align * sizeof(T), C * sizeof(T), n_seg); + loadInput(nram_target + (i % 2) * pingpong_offset, + (int32_t *)target + (i + 2) * n_seg, target_size); + compute((T *)(nram_input + ((i + 1) % 2) * pingpong_offset), + (int32_t *)(nram_target + ((i + 1) % 2) * pingpong_offset), + (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, + (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_output + ((i + 1) % 2) * pingpong_offset)); + __asm__ volatile("sync;"); + } + + if (repeat > 1) { + storeOutput((T *)output + (repeat - 2) * C * n_seg, + nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), + C * sizeof(T), c_align * sizeof(T), n_seg); + } + if (remain > 0) { + loadInput(nram_input + (repeat % 2) * pingpong_offset, + (T *)input + repeat * C * n_seg, C * sizeof(T), + c_align * sizeof(T), C * sizeof(T), remain); + loadInput(nram_target + (repeat % 2) * pingpong_offset, + (int32_t *)target + repeat * n_seg, + remain * sizeof(int32_t)); + } + if (repeat > 0) { + compute((T *)(nram_input + ((repeat - 1) % 2) * pingpong_offset), + (int32_t *)(nram_target + ((repeat - 1) % 2) * pingpong_offset), + (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, + (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_output + ((repeat - 1) % 2) * pingpong_offset)); + } + __asm__ volatile("sync;"); + + if (repeat > 0) { + storeOutput((T *)output + (repeat - 1) * C * n_seg, + nram_output + ((repeat - 1) % 2) * pingpong_offset, + C * sizeof(T), C * sizeof(T), c_align * sizeof(T), n_seg); + } + if (remain > 0) { + int rem_deal_num = remain * c_align_size / sizeof(T); + compute((T *)(nram_input + (repeat % 2) * pingpong_offset), + (int32_t *)(nram_target + (repeat % 2) * pingpong_offset), + (T *)nram_weight, weight_size, rem_deal_num, remain, C, alpha, + gamma, (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_output + (repeat % 2) * pingpong_offset)); + __asm__ volatile("sync;"); + + storeOutput((T *)output + repeat * C * n_seg, + nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), + C * sizeof(T), c_align * sizeof(T), remain); + } +} + +template +__mlu_global__ void MLUUnion1KernelFocalLossSigmoidForward( + const void *input, const void *target, const void *weight, const int32_t N, + const int32_t C, const float alpha, const float gamma, void *output) { + const int32_t n_seg = N / taskDim + (taskId == taskDim - 1) * (N % taskDim); + const T *input_offset = (T *)input + N / taskDim * taskId * C; + const int32_t *target_offset = (int32_t *)target + N / taskDim * taskId; + T *output_offset = (T *)output + N / taskDim * taskId * C; + + focalLossSigmoidForwardBlock((T *)input_offset, (int32_t *)target_offset, + (T *)weight, n_seg, C, alpha, gamma, + (T *)output_offset); +} +} // namespace forward + +namespace backward { +template +__mlu_func__ void loadInput(char *nram_input, char *nram_target, + const T *gdram_input, const int32_t *gdram_target, + const int32_t deal_n, const int32_t total_c, + const bool pingping_flag, const bool has_weight, + const int32_t nram_offset, + const int32_t gdram_offset) { + if (pingping_flag == PONG) { + nram_input += nram_offset; + nram_target += nram_offset; + } + + __memcpy_async(nram_target, gdram_target + gdram_offset / total_c, + deal_n * sizeof(int32_t), GDRAM2NRAM); + + char *nram_input_load = nram_input; + int32_t compute_align_size = 2 * NFU_ALIGN_SIZE; + if (has_weight) { + if (sizeof(T) == sizeof(half)) { + int32_t compute_align_num = compute_align_size / sizeof(float); + int32_t align_c = PAD_UP(total_c, compute_align_num); + int32_t compute_size = deal_n * align_c * sizeof(float); + nram_input_load += compute_size / 2; + } + int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T)); + int32_t total_c_size = total_c * sizeof(T); + int32_t align_c_size = align_c * sizeof(T); + __memcpy_async(nram_input_load, gdram_input + gdram_offset, total_c_size, + GDRAM2NRAM, align_c_size, total_c_size, deal_n - 1); + } else { + if (sizeof(T) == sizeof(half)) { + int32_t compute_size = + PAD_UP(deal_n * total_c * sizeof(float), compute_align_size); + nram_input_load += compute_size / 2; + } + int32_t load_size = deal_n * total_c * sizeof(T); + __memcpy_async(nram_input_load, gdram_input + gdram_offset, load_size, + GDRAM2NRAM); + } +} + +template +__mlu_func__ void sigmoid(T *dst_data, const T *src_data, + const int32_t elem_count) { + __bang_mul_const(dst_data, (T *)src_data, T(-1), elem_count); + __bang_active_exphp(dst_data, dst_data, elem_count); + __bang_add_const(dst_data, dst_data, T(1), elem_count); + __bang_active_reciphp(dst_data, dst_data, elem_count); +} + +template +__mlu_func__ void coreCompute(char *nram_input, const T *nram_weight, + const float *nram_flt_min, char *nram_pt, + char *nram_alpha_t, char *nram_temp, + char *nram_target, const float *nram_gamma, + char *nram_output, const float alpha, + const int32_t compute_num, const int32_t deal_n, + const int32_t total_c, const bool pingpong_flag, + const int32_t nram_offset, + const bool has_weight) { + if (pingpong_flag == PONG) { + nram_input += nram_offset; + nram_pt += nram_offset; + nram_alpha_t += nram_offset; + nram_temp += nram_offset; + nram_output += nram_offset; + nram_target += nram_offset; + } + + if (sizeof(T) == sizeof(half)) { + const int32_t compute_size = compute_num * sizeof(float); + char *nram_input_load = nram_input + compute_size / 2; + __bang_half2float((float *)nram_input, (half *)nram_input_load, + compute_num); + } + + // 0. alpha_t = alpha - 1 + __nramset((float *)nram_alpha_t, compute_num, (float)(alpha - 1.0)); + + // 1. pt = 1 - sigmoid(x) + sigmoid((float *)nram_pt, (float *)nram_input, compute_num); + __bang_mul_const((float *)nram_pt, (float *)nram_pt, (float)(-1), + compute_num); + __bang_add_const((float *)nram_pt, (float *)nram_pt, (float)1, compute_num); + + // 2. pt = target[n] == c ? sigmoid(x) : 1 - sigmoid(x) + // alpha_t = target[n] == c ? alpha : alpha - 1 + const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(float); + for (int n = 0; n < deal_n; n++) { + const int32_t target_value = ((int32_t *)nram_target)[n]; + if (target_value >= total_c || target_value < 0) continue; + int32_t c_offset = 0; + if (has_weight) { + int32_t c_align_num = nfu_align_num; + if (sizeof(T) == sizeof(half)) { + c_align_num += nfu_align_num; + } + c_offset = PAD_UP(total_c, c_align_num); + } else { + c_offset = total_c; + } + int32_t idx = n * c_offset + target_value; + *((float *)nram_pt + idx) = 1.0 - *((float *)nram_pt + idx); + *((float *)nram_alpha_t + idx) = alpha; + } + + // 3. temp = -alpha_t * e^(gamma * log(max(1 - pt, FLT_MIN)) + __bang_mul_const((float *)nram_temp, (float *)nram_pt, (float)(-1), + compute_num); + __bang_add_const((float *)nram_temp, (float *)nram_temp, (float)(1), + compute_num); + __bang_cycle_maxequal((float *)nram_temp, (float *)nram_temp, + (float *)nram_flt_min, compute_num, nfu_align_num); + __bang_active_loghp((float *)nram_temp, (float *)nram_temp, compute_num); + __bang_cycle_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_gamma, + compute_num, nfu_align_num); + __bang_active_exphp((float *)nram_temp, (float *)nram_temp, compute_num); + __bang_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_alpha_t, + compute_num); + __bang_mul_const((float *)nram_temp, (float *)nram_temp, (float)(-1), + compute_num); + + // 4. output = 1 - pt - gamma * pt * log(max(pt, FLT_MIN)) + __bang_cycle_maxequal((float *)nram_output, (float *)nram_pt, + (float *)nram_flt_min, compute_num, nfu_align_num); + __bang_active_loghp((float *)nram_output, (float *)nram_output, compute_num); + __bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_pt, + compute_num); + __bang_cycle_mul((float *)nram_output, (float *)nram_output, + (float *)nram_gamma, compute_num, nfu_align_num); + __bang_add((float *)nram_output, (float *)nram_output, (float *)nram_pt, + compute_num); + __bang_mul_const((float *)nram_output, (float *)nram_output, (float)(-1), + compute_num); + __bang_add_const((float *)nram_output, (float *)nram_output, (float)(1), + compute_num); + + // 5. output = output * temp + __bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_temp, + compute_num); + + if (sizeof(T) == sizeof(half)) { + __bang_float2half_rd((half *)nram_output, (float *)nram_output, + compute_num); + } + + if (has_weight) { + // with weight + for (int n = 0; n < deal_n; n++) { + int32_t c_align_num = nfu_align_num; + if (sizeof(T) == sizeof(half)) { + c_align_num += nfu_align_num; + } + int32_t align_c = PAD_UP(total_c, c_align_num); + int32_t target_value = ((int32_t *)nram_target)[n]; + T weight_value = nram_weight[target_value]; + __bang_mul_const((T *)nram_output + n * align_c, + (T *)nram_output + n * align_c, weight_value, align_c); + } + } +} + +template +__mlu_func__ void storeOutput(T *gdram_output, const char *nram_output, + const int32_t deal_n, const int32_t total_c, + const bool pingpong_flag, const bool has_weight, + const int32_t nram_offset, + const int32_t gdram_offset) { + if (pingpong_flag == PONG) { + nram_output += nram_offset; + } + const int32_t store_size = deal_n * total_c * sizeof(T); + if (has_weight) { + int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T)); + int32_t total_c_size = total_c * sizeof(T); + int32_t align_c_size = align_c * sizeof(T); + __memcpy_async(gdram_output + gdram_offset, nram_output, total_c_size, + NRAM2GDRAM, total_c_size, align_c_size, deal_n - 1); + } else { + __memcpy_async(gdram_output + gdram_offset, nram_output, store_size, + NRAM2GDRAM); + } +} + +template +__mlu_func__ void focalLossSigmoidBackwardBlock( + const T *input, const int32_t *target, const T *weight, const float gamma, + const float alpha, const int32_t total_n, const int32_t deal_n, + const int32_t total_c, T *output) { + // params per time slice + int32_t deal_num = deal_n * total_c; + int32_t deal_size = deal_num * sizeof(float); + int32_t compute_num = 0; + int32_t compute_size = 0; + int32_t compute_align_size = NFU_ALIGN_SIZE; + const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(T); + if (sizeof(T) == sizeof(half)) { + compute_align_size += NFU_ALIGN_SIZE; + } + const int32_t compute_align_num = compute_align_size / sizeof(float); + bool has_weight = false; + if (weight != NULL) { + has_weight = true; + int32_t align_c = PAD_UP(total_c, compute_align_num); + compute_num = deal_n * align_c; + compute_size = compute_num * sizeof(float); + } else { + compute_size = PAD_UP(deal_size, compute_align_size); + compute_num = compute_size / sizeof(float); + } + + // params per core + int32_t total_num = total_n * total_c; + int32_t num_per_core = PAD_DOWN(total_num / taskDim, deal_num); + int32_t loop_per_core = num_per_core / deal_num; + + /* NRAM partition: + * + * |-----------------ping pong--------------------| + * |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight| + * + * split_pipeline_num is 5: input, pt, alpha_t, temp, output. + * nram_reserved_line_num is 2: flt_min, gamma. + */ + const int32_t split_pipeline_num = 5; + const int32_t nram_reserved_line_num = 2; + int32_t target_deal_size = deal_n * sizeof(int32_t); + int32_t target_deal_size_align = PAD_UP(target_deal_size, NFU_ALIGN_SIZE); + // nram PING/PONG offset + int32_t ping_pong_offset = + compute_size * split_pipeline_num + target_deal_size_align; + + // gdram addr + int32_t *base_addr_target = + (int32_t *)target + taskId * loop_per_core * deal_n; + T *base_addr_input = (T *)input + taskId * num_per_core; + T *base_addr_output = output + taskId * num_per_core; + + // nram addr + char *nram_input = (char *)nram_buffer; + char *nram_pt = nram_input + compute_size; + char *nram_alpha_t = nram_pt + compute_size; + char *nram_temp = nram_alpha_t + compute_size; + char *nram_output = nram_temp + compute_size; + char *nram_target = nram_output + compute_size; + float *nram_flt_min = NULL; + float *nram_gamma = NULL; + T *nram_weight = NULL; + + if (!has_weight) { + nram_flt_min = (float *)(nram_buffer + MAX_NRAM_SIZE - + nram_reserved_line_num * NFU_ALIGN_SIZE); + nram_gamma = nram_flt_min + nfu_align_num; + } else { + int32_t weight_space = PAD_UP(total_c * sizeof(T), NFU_ALIGN_SIZE); + nram_flt_min = + (float *)(nram_buffer + MAX_NRAM_SIZE - + nram_reserved_line_num * NFU_ALIGN_SIZE - weight_space); + nram_gamma = nram_flt_min + nfu_align_num; + nram_weight = (T *)(nram_gamma + nfu_align_num); + __memcpy_async(nram_weight, weight, total_c * sizeof(T), GDRAM2NRAM); + } + + // nram set gamma and FLT_MIN + __nramset(nram_gamma, nfu_align_num, gamma); + __nramset(nram_flt_min, nfu_align_num, FLT_MIN); + + /* + * Pipeline: The pipeline is processed in three stages: Load, Compute, Store. + * The allocated memory space of NRAM is divided into two parts: + * PING and Pong. In a single time slice, PING is used to process + * IO stream and PONG is used for computation. Both of them are + * processed synchronously until finished. + * + * diagram of PINGPONG: + * |------|-----------------------------------------------------------------| + * | | space | + * |------|-----------------------------------------------------------------| + * | time | Ping | Pong | Ping | Pong | Ping | Pong | + * |------|-----------------------------------------------------------------| + * | 0 | L0 | | | | | | + * | 1 | C0 | L1 | | | | | + * | 2 | S0 | C1 | L2 | | | | + * | 3 | | S1 | C2 | L3 | | | + * | 4 | | | S2 | C3 | L4 | | + * | 5 | | | | S3 | C4 | L5 | + * | 6 | | | | | S4 | C5 | + * | 7 | | | | | | S5 | + * |------|-----------------------------------------------------------------| + */ + + // diagram of PINGPONG: L0 + if (loop_per_core > 0) { + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + deal_n, total_c, PING, has_weight, ping_pong_offset, 0); + __asm__ volatile("sync;"); + } + + // diagram of PINGPONG: C0 and L1 + if (loop_per_core > 1) { + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PING, ping_pong_offset, + has_weight); + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + deal_n, total_c, PONG, has_weight, ping_pong_offset, deal_num); + __asm__ volatile("sync;"); + } + + for (int i = 0; i < loop_per_core - 2; ++i) { + if (i % 2 == PING) { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PING, + has_weight, ping_pong_offset, i * deal_num); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PONG, ping_pong_offset, + has_weight); + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + deal_n, total_c, PING, has_weight, ping_pong_offset, + (i + 2) * deal_num); + } else { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG, + has_weight, ping_pong_offset, i * deal_num); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PING, ping_pong_offset, + has_weight); + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + deal_n, total_c, PONG, has_weight, ping_pong_offset, + (i + 2) * deal_num); + } + __asm__ volatile("sync;"); + } + + if (loop_per_core > 1) { + if ((loop_per_core - 2) % 2 == PING) { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PING, + has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PONG, ping_pong_offset, + has_weight); + } else { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG, + has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PING, ping_pong_offset, + has_weight); + } + __asm__ volatile("sync;"); + } + + if (loop_per_core > 0) { + if (loop_per_core == 1) { + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + compute_num, deal_n, total_c, PING, ping_pong_offset, + has_weight); + __asm__ volatile("sync;"); + } + if ((loop_per_core - 1) % 2 == PING) { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PING, + has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num); + } else { + storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG, + has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num); + } + } + + // process the remaining data which N remainder per core is less than deal_n + int32_t rem_for_all = total_num - num_per_core * taskDim; + if (rem_for_all == 0) return; + int32_t rem_n_for_all = rem_for_all / total_c; + int32_t rem_n_per_core = (rem_n_for_all + taskDim - 1) / taskDim; + int32_t rem_num_per_core = rem_n_per_core * total_c; + int32_t rem_num_per_core_align = 0; + int32_t rem_core_num = rem_for_all / rem_num_per_core; + + int32_t rem_n_for_last = rem_n_for_all % rem_n_per_core; + int32_t rem_num_for_last = rem_n_for_last * total_c; + int32_t rem_num_for_last_align = 0; + + if (has_weight) { + int32_t align_c = PAD_UP(total_c, compute_align_num); + rem_num_per_core_align = rem_n_per_core * align_c; + rem_num_for_last_align = rem_n_for_last * align_c; + } else { + rem_num_per_core_align = PAD_UP(rem_num_per_core, compute_align_num); + rem_num_for_last_align = PAD_UP(rem_num_for_last, compute_align_num); + } + + int32_t rem_addr_base = num_per_core * taskDim; + int32_t rem_target_addr_base = loop_per_core * deal_n * taskDim; + base_addr_target = (int32_t *)target + rem_target_addr_base; + base_addr_input = (T *)input + rem_addr_base; + base_addr_output = output + rem_addr_base; + + if (taskId < rem_core_num) { + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + rem_n_per_core, total_c, PING, has_weight, ping_pong_offset, + taskId * rem_num_per_core); + __asm__ volatile("sync;"); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + rem_num_per_core_align, rem_n_per_core, total_c, PING, + ping_pong_offset, has_weight); + __asm__ volatile("sync;"); + storeOutput(base_addr_output, nram_output, rem_n_per_core, total_c, PING, + has_weight, ping_pong_offset, taskId * rem_num_per_core); + } else if (taskId == rem_core_num) { + if (rem_num_for_last == 0) return; + loadInput(nram_input, nram_target, base_addr_input, base_addr_target, + rem_n_for_last, total_c, PING, has_weight, ping_pong_offset, + taskId * rem_num_per_core); + __asm__ volatile("sync;"); + coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t, + nram_temp, nram_target, nram_gamma, nram_output, alpha, + rem_num_for_last_align, rem_n_for_last, total_c, PING, + ping_pong_offset, has_weight); + __asm__ volatile("sync;"); + storeOutput(base_addr_output, nram_output, rem_n_for_last, total_c, PING, + has_weight, ping_pong_offset, taskId * rem_num_per_core); + } else { + return; + } +} + +template +__mlu_global__ void MLUUnion1KernelFocalLossSigmoidBackward( + const void *input, const void *target, const void *weight, + const float gamma, const float alpha, const int32_t total_n, + const int32_t deal_n, const int32_t total_c, void *output) { + focalLossSigmoidBackwardBlock((T *)input, (int32_t *)target, (T *)weight, + gamma, alpha, total_n, deal_n, total_c, + (T *)output); +} +} // namespace backward + +void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const void *input, const void *target, + const void *weight, const int32_t N, + const int32_t C, const float alpha, + const float gamma, void *output) { + if (d_type == CNRT_FLOAT16) { + forward::MLUUnion1KernelFocalLossSigmoidForward< + half><<>>(input, target, weight, N, C, alpha, + gamma, output); + } else { + forward::MLUUnion1KernelFocalLossSigmoidForward< + float><<>>(input, target, weight, N, C, alpha, + gamma, output); + } +} + +void KernelFocalLossSigmoidBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const void *input, const void *target, + const void *weight, const float gamma, + const float alpha, const int32_t dim_n, + const int32_t deal_n, const int32_t dim_c, + void *output) { + if (d_type == CNRT_FLOAT16) { + backward::MLUUnion1KernelFocalLossSigmoidBackward< + half><<>>(input, target, weight, gamma, alpha, + dim_n, deal_n, dim_c, output); + } else { + backward::MLUUnion1KernelFocalLossSigmoidBackward< + float><<>>(input, target, weight, gamma, alpha, + dim_n, deal_n, dim_c, output); + } +} diff --git a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp index c7f9f35b7b..4f198ac37b 100644 --- a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp @@ -10,8 +10,10 @@ using namespace at; #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_MLU(x) \ + TORCH_CHECK(x.device().type() == at::kMLU, #x " must be a MLU tensor") #define CHECK_CPU(x) \ - TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor") + TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CUDA_INPUT(x) \ diff --git a/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp new file mode 100644 index 0000000000..cd6fc568bb --- /dev/null +++ b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp @@ -0,0 +1,26 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef PYTORCH_MLU_HELPER_HPP_ +#define PYTORCH_MLU_HELPER_HPP_ + +#ifdef MMCV_WITH_MLU +#include "aten.h" + +#define NFU_ALIGN_SIZE 128 + +#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y)) + +#define PAD_DOWN(x, y) (((x) / (y)) * (y)) + +#endif + +#endif // PYTORCH_MLU_HELPER_HPP_ diff --git a/mmcv/ops/csrc/pytorch/focal_loss.cpp b/mmcv/ops/csrc/pytorch/focal_loss.cpp index 3e2c92b27a..a0d878ff36 100644 --- a/mmcv/ops/csrc/pytorch/focal_loss.cpp +++ b/mmcv/ops/csrc/pytorch/focal_loss.cpp @@ -1,5 +1,8 @@ // Copyright (c) OpenMMLab. All rights reserved #include "pytorch_cpp_helper.hpp" +#ifdef MMCV_WITH_MLU +#include "pytorch_mlu_helper.hpp" +#endif #ifdef MMCV_WITH_CUDA void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target, @@ -52,6 +55,31 @@ void softmax_focal_loss_backward_cuda(Tensor input, Tensor target, } #endif +#ifdef MMCV_WITH_MLU +void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, + Tensor weight, Tensor output, + const float gamma, + const float alpha); + +void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target, + Tensor weight, Tensor grad_input, + const float gamma, + const float alpha); + +void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + SigmoidFocalLossForwardMLUKernelLauncher(input, target, weight, output, gamma, + alpha); +} + +void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, + float alpha) { + SigmoidFocalLossBackwardMLUKernelLauncher(input, target, weight, grad_input, + gamma, alpha); +} +#endif + void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { if (input.device().is_cuda()) { @@ -65,6 +93,12 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, alpha); #else AT_ERROR("SigmoidFocalLoss is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (input.device().type() == at::kMLU) { + CHECK_MLU(input); + CHECK_MLU(target); + sigmoid_focal_loss_forward_mlu(input, target, weight, output, gamma, alpha); #endif } else { AT_ERROR("SigmoidFocalLoss is not implemented on CPU"); @@ -84,6 +118,16 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight, alpha); #else AT_ERROR("SigmoidFocalLoss is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (input.device().type() == at::kMLU) { + CHECK_MLU(input); + CHECK_MLU(target); + CHECK_MLU(weight); + CHECK_MLU(grad_input); + + sigmoid_focal_loss_backward_mlu(input, target, weight, grad_input, gamma, + alpha); #endif } else { AT_ERROR("SigmoidFocalLoss is not implemented on CPU"); diff --git a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp new file mode 100644 index 0000000000..044e8dd011 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp @@ -0,0 +1,315 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include + +#include "pytorch_mlu_helper.hpp" + +void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const void *input, const void *target, + const void *weight, const int32_t N, + const int32_t C, const float alpha, + const float gamma, void *output); + +void KernelFocalLossSigmoidBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, + const cnrtDataType_t d_type, + const void *input, const void *target, + const void *weight, const float gamma, + const float alpha, const int32_t dim_n, + const int32_t deal_n, const int32_t dim_c, + void *output); +// Policy Function for Forward +static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, + const Tensor &input, const Tensor &target, + const Tensor &weight) { + auto N = input.size(0); + auto C = input.size(1); + + auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + auto c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE); + const int split_target_num = 2; + const int split_pipeline_num = 6; + auto scalar_size = NFU_ALIGN_SIZE; + auto weight_size = c_align_size; + const int target_data_width = target.scalar_type() == at::kLong + ? target.itemsize() / 2 + : target.itemsize(); + + // n_seg * c_align_size * split_pipeline_num + + // n_seg * target.itemsize() * split_target_num + + // weight_size + scalar_size <= nram_size + auto n_seg = (nram_size - weight_size - scalar_size) / + (c_align_size * split_pipeline_num + + target_data_width * split_target_num); + auto seg_num = (N + n_seg - 1) / n_seg; + + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + auto core_num = core_dim * cluster_num; + + k_dim->x = *k_type; + k_dim->y = + seg_num > core_num ? cluster_num : (seg_num + core_dim - 1) / core_dim; + k_dim->z = 1; +} + +// Policy Function for Backward +static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { + // set Union1 Job + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + k_dim->z = 1; +} + +void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, + Tensor weight, Tensor output, + const float gamma, + const float alpha) { + // params check + TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ", + "But now gamma is ", gamma, "."); + + // check dtype + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "Data type of input should be Float or Half. But now input type is ", + input.scalar_type(), "."); + + TORCH_CHECK( + (target.scalar_type() == at::kInt || target.scalar_type() == at::kLong), + "target type should be Int or Long. ", "But now target type is ", + target.scalar_type(), "."); + + if (weight.data_ptr() != nullptr) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type(), + "Data types of input and weight should be the same. But now " + "input type is ", + input.scalar_type(), ", weight type is ", weight.scalar_type(), + "."); + } else { + CNLOG(INFO) << "weight is a empty tensor."; + } + + // check C + auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + auto input_N = input.size(0); + auto input_C = input.size(1); + const int split_target_num = 2; + const int split_pipeline_num = 6; + const int has_weight = (int)(weight.data_ptr() != nullptr); + + // target supports only INT on MLU device while it keeps LONG on host side, + // so target.itemsize() / 2 + const int target_data_width = target.scalar_type() == at::kLong + ? target.itemsize() / 2 + : target.itemsize(); + auto threshold_C = PAD_DOWN((nram_size - NFU_ALIGN_SIZE - + split_target_num * target_data_width) / + (split_pipeline_num + has_weight), + NFU_ALIGN_SIZE) / + input.itemsize(); + + TORCH_CHECK(threshold_C >= input_C, + "input.size(1) should be in the range of [0, ", threshold_C, + "]. ", "But now input.size(1) is ", input_C, "."); + + if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) { + // return if zero-element + return; + } + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1; + policyFuncForward(&k_dim, &k_type, input, target, weight); + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto input_impl = torch_mlu::getMluTensorImpl(input); + auto input_ptr = input_impl->cnnlMalloc(); + auto target_impl = torch_mlu::getMluTensorImpl(target); + auto target_ptr = target_impl->cnnlMalloc(); + auto weight_impl = torch_mlu::getMluTensorImpl(weight); + auto weight_ptr = weight_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + + // get dtype of input + cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype()); + + CNLOG(INFO) << "Launch Kernel KernelFocalLossSigmoidForward<<>>"; + // launch kernel + KernelFocalLossSigmoidForward(k_dim, k_type, queue, d_type, input_ptr, + target_ptr, weight_ptr, input_N, input_C, alpha, + gamma, output_ptr); +} + +void getDealNAndThresholdC(const int compute_data_bytes, + const int target_data_bytes, const int total_c, + int *deal_n_ptr, int *threshold_c_ptr, + const bool has_weight, const bool is_half) { + /* NRAM partition: + * + * |-----------------ping pong--------------------| + * |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight| + * + * split_pipeline_num is 5: including input, pt, alpha_t, temp, output. + */ + const int nram_split_num = 5; + const int nram_split_pingpong = 2; + const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + int32_t compute_align_size = NFU_ALIGN_SIZE; + if (is_half) { + compute_align_size += NFU_ALIGN_SIZE; + } + const int32_t compute_align_num = compute_align_size / compute_data_bytes; + // reservered_align_size: including input(ping pong), pt(ping pong), + // alpha_t(ping pong), temp(ping pong), + // output(ping pong), target(ping pong), + // flt_min and gamma. + const int reservered_align_size = + ((nram_split_num + 1) * nram_split_pingpong + 2) * compute_align_size; + int nram_pingpong_size = max_nram_size - reservered_align_size; + + int compute_c = total_c; + int threshold_c = 0; + if (has_weight) { + // reserved space for weight to align + nram_pingpong_size -= NFU_ALIGN_SIZE; + + // threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num + + // nram_split_pingpong * target_data_bytes + + // threshold_c * compute_data_bytes <= nram_pingpong_size + threshold_c = + (nram_pingpong_size - nram_split_pingpong * target_data_bytes) / + (compute_data_bytes * (nram_split_num * nram_split_pingpong + 1)); + threshold_c = PAD_DOWN(threshold_c, compute_align_num); + int weight_space = PAD_UP(total_c * compute_data_bytes, NFU_ALIGN_SIZE); + + // reserved space for weight + nram_pingpong_size -= weight_space; + compute_c = PAD_UP(total_c, compute_align_num); + } else { + // threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num + + // nram_split_pingpong * target_data_bytes <= nram_pingpong_size + threshold_c = + (nram_pingpong_size / nram_split_pingpong - target_data_bytes) / + (nram_split_num * compute_data_bytes); + } + // deal_n * compute_c * nram_split_pingpong * compute_data_bytes * + // nram_split_num + deal_n * nram_split_pingpong * target_data_bytes <= + // nram_pingpong_size + *deal_n_ptr = + nram_pingpong_size / + ((nram_split_num * compute_c * compute_data_bytes + target_data_bytes) * + nram_split_pingpong); + *threshold_c_ptr = threshold_c; +} + +void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target, + Tensor weight, Tensor output, + const float gamma, + const float alpha) { + // params check + TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ", + "But now gamma is ", gamma, "."); + // check dtype + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "Data type of input should be Float or Half. But now input type is ", + input.scalar_type(), "."); + + TORCH_CHECK( + (target.scalar_type() == at::kInt || target.scalar_type() == at::kLong), + "target type should be Int or Long. ", "But now target type is ", + target.scalar_type(), "."); + + bool has_weight = false; + if (weight.data_ptr() != nullptr) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type(), + "Data types of input and weight should be the same. But now " + "input type is ", + input.scalar_type(), ", weight type is ", weight.scalar_type(), + "."); + has_weight = true; + } else { + CNLOG(INFO) << "weight is a empty tensor."; + } + + auto dim_c = input.size(1); + const int compute_data_bytes = sizeof(float); + // target supports only INT on MLU device while it keeps LONG on host side, + // so target.itemsize() / 2 + const int target_data_bytes = target.scalar_type() == at::kLong + ? (target.itemsize() / 2) + : target.itemsize(); + int deal_n = 0; + int threshold_c = 0; + bool is_half = false; + if (input.scalar_type() == at::kHalf) { + is_half = true; + } + // calculate deal_n and threshold_c + getDealNAndThresholdC(compute_data_bytes, target_data_bytes, dim_c, &deal_n, + &threshold_c, has_weight, is_half); + + // check C + TORCH_CHECK(threshold_c >= dim_c, + "input.size(1) should be in the range of [0, ", threshold_c, + "]. ", "But now input.size(1) is ", dim_c, "."); + + if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) { + // return if zero-element + return; + } + + // set task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFuncBackward(&k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto input_impl = torch_mlu::getMluTensorImpl(input); + auto input_ptr = input_impl->cnnlMalloc(); + auto target_impl = torch_mlu::getMluTensorImpl(target); + auto target_ptr = target_impl->cnnlMalloc(); + auto weight_impl = torch_mlu::getMluTensorImpl(weight); + auto weight_ptr = weight_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + + // get dtype of input + cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype()); + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + auto dim_n = input.size(0); + + CNLOG(INFO) << "Launch Kernel KernelFocalLossSigmoidBackward<<>>"; + + // launch kernel + KernelFocalLossSigmoidBackward(k_dim, k_type, queue, d_type, input_ptr, + target_ptr, weight_ptr, gamma, alpha, dim_n, + deal_n, dim_c, output_ptr); +} diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 763bc93bd2..8058605161 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -34,7 +34,8 @@ def forward(ctx, weight=None, reduction='mean'): - assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert isinstance( + target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor)) assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 378a006843..2619f2e978 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -44,6 +44,7 @@ PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) + from .pytorch_wrapper import is_cuda, is_mlu from .registry import Registry, build_from_cfg from .trace import is_jit_tracing __all__ = [ diff --git a/mmcv/utils/pytorch_wrapper.py b/mmcv/utils/pytorch_wrapper.py new file mode 100644 index 0000000000..0462e0d18d --- /dev/null +++ b/mmcv/utils/pytorch_wrapper.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +TORCH_VERSION = torch.__version__ + + +def is_cuda() -> bool: + return torch.cuda.is_available() + + +def is_mlu() -> bool: + if TORCH_VERSION != 'parrots': + try: + return torch.is_mlu_available() + except AttributeError: + return False + return False diff --git a/setup.py b/setup.py index 4dae615217..8a7974ad2b 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,11 @@ from parrots.utils.build_extension import BuildExtension EXT_TYPE = 'parrots' else: - from torch.utils.cpp_extension import BuildExtension + try: + if torch.is_mlu_available(): + from torch_mlu.utils.cpp_extension import BuildExtension + except AttributeError: + from torch.utils.cpp_extension import BuildExtension EXT_TYPE = 'pytorch' cmd_class = {'build_ext': BuildExtension} except ModuleNotFoundError: @@ -262,10 +266,25 @@ def get_extensions(): include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda')) else: - print(f'Compiling {ext_name} without CUDA') - op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') - extension = CppExtension - include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + try: + if torch.is_mlu_available(): + from torch_mlu.utils.cpp_extension import MLUExtension + define_macros += [('MMCV_WITH_MLU', None)] + mlu_args = os.getenv('MMCV_MLU_ARGS') + extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + op_files += glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + op_files += glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + extension = MLUExtension + include_dirs.append( + os.path.abspath('./mmcv/ops/csrc/common')) + else: + print('Cambricon Catch is not available!') + except AttributeError: + print(f'Compiling {ext_name} without CUDA') + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + extension = CppExtension + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) ext_ops = extension( name=ext_name, diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py index e52f060f6a..f893da8e1b 100644 --- a/tests/test_ops/test_focal_loss.py +++ b/tests/test_ops/test_focal_loss.py @@ -1,6 +1,9 @@ import numpy as np +import pytest import torch +from mmcv.utils import is_cuda, is_mlu + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -56,9 +59,7 @@ def _test_softmax(self, dtype=torch.float): assert np.allclose(loss.data.cpu().numpy(), output[0], 1e-2) assert np.allclose(x.grad.data.cpu(), np_x_grad, 1e-2) - def _test_sigmoid(self, dtype=torch.float): - if not torch.cuda.is_available(): - return + def _test_sigmoid(self, device, dtype=torch.float): from mmcv.ops import sigmoid_focal_loss alpha = 0.25 gamma = 2.0 @@ -67,9 +68,9 @@ def _test_sigmoid(self, dtype=torch.float): np_y = np.array(case[1]) np_x_grad = np.array(output[1]) - x = torch.from_numpy(np_x).cuda().type(dtype) + x = torch.from_numpy(np_x).to(device).type(dtype) x.requires_grad_() - y = torch.from_numpy(np_y).cuda().long() + y = torch.from_numpy(np_y).to(device).long() loss = sigmoid_focal_loss(x, y, gamma, alpha, None, 'mean') loss.backward() @@ -127,11 +128,31 @@ def test_softmax_float(self): def test_softmax_half(self): self._test_softmax(dtype=torch.half) - def test_sigmoid_float(self): - self._test_sigmoid(dtype=torch.float) - - def test_sigmoid_half(self): - self._test_sigmoid(dtype=torch.half) + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not is_mlu(), reason='requires MLU support')) + ]) + def test_sigmoid_float(self, device): + self._test_sigmoid(device=device, dtype=torch.float) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not is_mlu(), reason='requires MLU support')) + ]) + def test_sigmoid_half(self, device): + self._test_sigmoid(device, dtype=torch.half) def test_grad_softmax_float(self): self._test_grad_softmax(dtype=torch.float) From 0900a631d47b11a7f7878b64fdb22e46e12f7bda Mon Sep 17 00:00:00 2001 From: zihanchang11 <92860914+zihanchang11@users.noreply.github.com> Date: Mon, 22 Nov 2021 11:50:31 +0800 Subject: [PATCH 15/30] [Feature] Support RoiAlign With Cambricon MLU Backend (#1429) --- .../csrc/common/mlu/roi_align_mlu_kernel.mlu | 583 ++++++++++++++++++ mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp | 171 +++++ mmcv/ops/csrc/pytorch/roi_align.cpp | 58 ++ tests/test_ops/test_roi_align.py | 29 +- 4 files changed, 834 insertions(+), 7 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu new file mode 100644 index 0000000000..e11aa4c575 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu @@ -0,0 +1,583 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" + +__nram__ char buffer[MAX_NRAM_SIZE]; + +#define ALIGN_SIZE 64 +#define MAX_ELEMENTS_FLOAT (50 * 1024) +#define MAX_ELEMENTS_HALF (100 * 1024) +#define ROI_OFFSET 5 +#define SAMPLING_NUM 4 + +#define DIM_BOX 5 +#define BLOCK_INPUT_OUTPUT 2 + +namespace forward { +template +__mlu_func__ void bilinearInterpolate(int input_height, int input_width, + float y, float x, T *w1, T *w2, T *w3, + T *w4, int *x_low, int *x_high, + int *y_low, int *y_high, int *empty, + T zero_sign) { + // deal with cases that inverse elements are of feature map boundary + if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { + *empty = 1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + *y_low = int(y); + *x_low = int(x); + + if (*y_low >= input_height - 1) { + *y_high = *y_low = input_height - 1; + y = (T)(*y_low); + } else { + *y_high = *y_low + 1; + } + + if (*x_low >= input_width - 1) { + *x_high = *x_low = input_width - 1; + x = (T)(*x_low); + } else { + *x_high = *x_low + 1; + } + + T ly = y - *y_low; + T lx = x - *x_low; + T hy = 1.0 - ly; + T hx = 1.0 - lx; + + *w1 = hy * hx * zero_sign; + *w2 = hy * lx * zero_sign; + *w3 = ly * hx * zero_sign; + *w4 = ly * lx * zero_sign; + + return; +} + +template +__mlu_func__ void roialignForwardKernel( + T *input, T *rois, T *output, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const int max_elements) { + /* + * NRAM partition + * |----------------------NRAM------ -----------------| + * | | + * | output | + * |--------------------------------------------------| + * | | + * | input | + * | | + * |--------------------------------------------------| + * | rois(batch_id, x1, y1, x2, y2) | + * |--------------------------------------------------| + * + * channel data will loop inside of input_nram, when channel * size(T) > + * input_nram + */ + + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int height = 0; + int width = 0; + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + + // multi-core params + int inter_num = num_rois / taskDim; + int rem_num = num_rois % taskDim; + int offset_length; + int task_length; + + // the length dealt by every core and the offset of taskid + if (taskId < rem_num) { + task_length = inter_num + 1; + offset_length = taskId * (inter_num + 1); + } else { + task_length = inter_num; + offset_length = rem_num * (inter_num + 1) + (taskId - rem_num) * inter_num; + } + + int max_size = max_elements >> 1; + T *nram_out = (T *)buffer; + T *nram_in = nram_out + max_size; + T *nram_rois = nram_in + max_elements; + + int pooled_size = pooled_height * pooled_width; + // output and roi data ptr + T *top_data = output + offset_length * pooled_size * channels; + T *task_rois = rois + offset_length * ROI_OFFSET; + + for (int roi_id = 0; roi_id < task_length; roi_id++) { + // For each roi, find the corresponding feature map which it belongs to, + // and compute the scaling_factor to map it to that feature map. + height = input_height; + width = input_width; + T offset = aligned ? (T)0.5 : (T)0; + + T *roi_id_tmp = task_rois + roi_id * ROI_OFFSET; + __bang_write_zero(nram_rois, ALIGN_SIZE); + __memcpy((void *)nram_rois, (void *)roi_id_tmp, ROI_OFFSET * sizeof(T), + GDRAM2NRAM); + + int batch_id = nram_rois[0]; + T roi_xmin = nram_rois[1]; + T roi_ymin = nram_rois[2]; + T roi_xmax = nram_rois[3]; + T roi_ymax = nram_rois[4]; + + roi_xmin = roi_xmin * spatial_scale - offset; + roi_ymin = roi_ymin * spatial_scale - offset; + roi_xmax = roi_xmax * spatial_scale - offset; + roi_ymax = roi_ymax * spatial_scale - offset; + + float roi_width = roi_xmax - roi_xmin; + float roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + float bin_size_h = (float)roi_height / pooled_height; + float bin_size_w = (float)roi_width / pooled_width; + + // input data ptr + T *offset_bottom_data = input + batch_id * channels * width * height; + T *tmp_sum = nram_out; + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_h); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_w); + float count = roi_bin_grid_h * roi_bin_grid_w; + float zero_sign_tmp = 1.0f / count; + + for (int ph = 0; ph < pooled_height; ph++) { + float y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + float x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (samp_channel_align < max_elements) { + // One aligned channel data can be computed at one time + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + float y = + (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 + ? 0 + : (y_pre + + ((iy + 0.5) * bin_size_h) / + (roi_bin_grid_h)); // center_point y + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + float x = + (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 + ? 0 + : (x_pre + + ((ix + 0.5) * bin_size_w) / + (roi_bin_grid_w)); // center_point x + T zero_sign = + (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * + zero_sign_tmp; + + int empty = 0; + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, + &w3, &w4, &x_low, &x_high, &y_low, &y_high, + &empty, zero_sign); + + // load + int cpy_len = (x_high - x_low) * channels; + int cpy_size = channels * sizeof(T); + + int offset1 = (y_low * width + x_low) * channels; + int offset2 = (y_high * width + x_low) * channels; + + T *tmp1 = offset_bottom_data + offset1; + T *tmp2 = offset_bottom_data + offset2; + + T *tmp_cyc1 = nram_in; + T *tmp_cyc2 = nram_in + channel_align; + T *tmp_cyc3 = nram_in + channel_align * 2; + T *tmp_cyc4 = nram_in + channel_align * 3; + __asm__ volatile("sync;"); + if (empty == 1) { + __nramset(nram_in, channel_align, T(0)); + } else { + // load gdram to nram + __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); + __asm__ volatile("sync;"); + // roialign_forward compute + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); + __bang_sumpool(nram_in, nram_in, channel_align, 1, SAMPLING_NUM, + 1, SAMPLING_NUM, 1, 1); + } + __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); + } + } + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / max_elements + + (int)(samp_channel % max_elements != 0); + int cyc_channel = max_elements / SAMPLING_NUM; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = + (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; + int align_channel = + (i == cyc_num - 1) + ? PAD_UP((channel_align - i * cyc_channel), ALIGN_SIZE) + : cyc_channel; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + float y = + (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 + ? 0 + : (y_pre + + ((iy + 0.5) * bin_size_h) / + (roi_bin_grid_h)); // center_point y + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + float x = + (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 + ? 0 + : (x_pre + + ((ix + 0.5) * bin_size_w) / + (roi_bin_grid_w)); // center_point x + + T zero_sign = + (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * + zero_sign_tmp; + + int empty = 0; + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, + &w3, &w4, &x_low, &x_high, &y_low, &y_high, + &empty, zero_sign); + + // load + int cpy_len = (x_high - x_low) * channels; + + int offset1 = (y_low * width + x_low) * channels; + int offset2 = (y_high * width + x_low) * channels; + + T *tmp1 = offset_bottom_data + offset1 + cyc_channel * i; + T *tmp2 = offset_bottom_data + offset2 + cyc_channel * i; + + T *tmp_cyc1 = nram_in; + T *tmp_cyc2 = nram_in + cyc_channel; + T *tmp_cyc3 = nram_in + cyc_channel * 2; + T *tmp_cyc4 = nram_in + cyc_channel * 3; + __asm__ volatile("sync;"); + if (empty == 1) { // exits abnormal values + __nramset(nram_in, align_channel, T(0)); + } else { + __memcpy_async(tmp_cyc1, tmp1, align_channel * sizeof(T), + GDRAM2NRAM); + __memcpy_async(tmp_cyc2, tmp1 + cpy_len, + align_channel * sizeof(T), GDRAM2NRAM); + __memcpy_async(tmp_cyc3, tmp2, align_channel * sizeof(T), + GDRAM2NRAM); + __memcpy_async(tmp_cyc4, tmp2 + cpy_len, + align_channel * sizeof(T), GDRAM2NRAM); + __asm__ volatile("sync;"); + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, align_channel); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, align_channel); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, align_channel); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, align_channel); + __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, + 1, SAMPLING_NUM, 1, 1); + } + __bang_add(tmp_sum, tmp_sum, nram_in, align_channel); + } + } + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + // copy output data to ddr when channel num is not aligned with 64 + if (samp_channel_align < max_elements) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph + } // loop for num_roi +} + +__mlu_global__ void MLUUnion1KernelRoialign( + const void *input, const void *rois, const int channels, const bool aligned, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const cnrtDataType_t data_type, void *output) { + int max_elements = + (data_type == CNRT_FLOAT32) ? MAX_ELEMENTS_FLOAT : MAX_ELEMENTS_HALF; + switch (data_type) { + case CNRT_FLOAT16: { + roialignForwardKernel((half *)input, (half *)rois, (half *)output, + aligned, channels, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + case CNRT_FLOAT32: { + roialignForwardKernel((float *)input, (float *)rois, (float *)output, + aligned, channels, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + default: + break; + } + return; +} +} // namespace forward + +namespace backward { +__mlu_func__ void bilinearInterpolateGradient(int height, int width, float y, + float x, float *w1, float *w2, + float *w3, float *w4, int *x_low, + int *x_high, int *y_low, + int *y_high) { + if (y < -1.0 || y > height || x < -1.0 || x > width) { + *w1 = 0.0, *w2 = 0.0, *w3 = 0.0, *w4 = 0.0; + *x_low = -1, *x_high = -1, *y_low = -1, *y_high = -1; + return; + } + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + *y_low = (int)y; + *x_low = (int)x; + if (*y_low >= height - 1) { + *y_high = height - 1, *y_low = height - 1; + y = (float)(*y_low); + } else { + *y_high = *y_low + 1; + } + if (*x_low >= width - 1) { + *x_high = width - 1, *x_low = width - 1; + x = (float)(*x_low); + } else { + *x_high = *x_low + 1; + } + float ly = y - *y_low, lx = x - *x_low; + float hy = 1.0 - ly, hx = 1.0 - lx; + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; + return; +} + +template +__mlu_func__ void unionRoiAlignBp( + T *grads, T *boxes, T *grads_image, const int boxes_num, const int hi, + const int wi, const int c, const int no, const int ho, const int wo, + const float spatial_scale, const int sampling_ratio, const bool aligned) { + int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); + int deal_this_core = + boxes_num / taskDim + (int)(taskId < boxes_num % taskDim); + for (int i = 0; i < deal_this_core; ++i) { + int box_id = i * taskDim + taskId; + T *box = boxes + box_id * DIM_BOX; + T *grads_offset = grads + box_id * hi * wi * c; + int image_id = (int)box[0]; + T *image_offset = grads_image + image_id * ho * wo * c; + + float offset = aligned ? 0.5 : 0.0; + float x1 = box[1] * spatial_scale - offset; + float y1 = box[2] * spatial_scale - offset; + float x2 = box[3] * spatial_scale - offset; + float y2 = box[4] * spatial_scale - offset; + float roi_width = x2 - x1; + float roi_height = y2 - y1; + if (!aligned) { + roi_width = (roi_width > 1.0) ? roi_width : 1.0; + roi_height = (roi_height > 1.0) ? roi_height : 1.0; + } + float bin_size_h = roi_height / hi; + float bin_size_w = roi_width / wi; + + int roi_grid_h = + (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi); + int roi_grid_w = + (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / wi); + const int count = roi_grid_h * roi_grid_w; + if (c_align * sizeof(T) * BLOCK_INPUT_OUTPUT <= MAX_NRAM_SIZE) { + for (int ih = 0; ih < hi; ++ih) { + for (int iw = 0; iw < wi; ++iw) { + T *grads_ = grads_offset + ih * wi * c + iw * c; + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, + &x_low, &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, + (T)(w1 / count), c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, + (T)(w2 / count), c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_high * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, + (T)(w3 / count), c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, + (T)(w4 / count), c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_high * c, + (T *)buffer + c_align, c); + } // x_low && y_low + } // ix + } // iy + } // iw + } // ih + } else { + for (int ih = 0; ih < hi; ++ih) { + for (int iw = 0; iw < wi; ++iw) { + T *grads_ = grads_offset + ih * wi * c + iw * c; + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, + &x_low, &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + int deal_once = PAD_DOWN(MAX_NRAM_SIZE / BLOCK_INPUT_OUTPUT, + NFU_ALIGN_SIZE) / + sizeof(T); + int c_repeat = c / deal_once + (int)(c % deal_once != 0); + for (int i = 0; i < c_repeat; ++i) { + int deal_c = deal_once; + int align_c = deal_once; + if (i == c_repeat - 1) { + deal_c = c - i * deal_once; + align_c = c_align - i * deal_once; + } + __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), + GDRAM2NRAM); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, + (T)(w1 / count), align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_low * wo * c + x_low * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, + (T)(w2 / count), align_c); + __bang_atomic_add((T *)buffer + align_c, + image_offset + y_low * wo * c + x_high * c + + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, + (T)(w3 / count), align_c); + __bang_atomic_add((T *)buffer + align_c, + image_offset + y_high * wo * c + x_low * c + + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, + (T)(w4 / count), align_c); + __bang_atomic_add((T *)buffer + align_c, + image_offset + y_high * wo * c + + x_high * c + i * deal_once, + (T *)buffer + align_c, deal_c); + } // for c_repeat + } // x_low >= 0 && y_low >= 0 + } // ix + } // iy + } // iw + } // ih + } // if c + } // i +} + +__mlu_global__ void MLUUnion1KernelRoiAlignBackward( + const void *grads, const void *boxes, void *grads_image, + const cnrtDataType_t dtype, const int boxes_num, const int hi, const int wi, + const int c, const int no, const int ho, const int wo, + const float spatial_scale, const int sampling_ratio, const bool aligned) { + // make sure that memcore is not used + if (coreId == 0x80) { + return; + } + switch (dtype) { + case CNRT_FLOAT16: { + unionRoiAlignBp((half *)grads, (half *)boxes, (half *)grads_image, + boxes_num, hi, wi, c, no, ho, wo, spatial_scale, + sampling_ratio, aligned); + }; break; + case CNRT_FLOAT32: { + unionRoiAlignBp((float *)grads, (float *)boxes, (float *)grads_image, + boxes_num, hi, wi, c, no, ho, wo, spatial_scale, + sampling_ratio, aligned); + }; break; + default: { return; } + } +} +} // namespace backward + +void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const void *input, const void *rois, const int channels, + const bool aligned, const int pooled_height, + const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, + void *output) { + forward::MLUUnion1KernelRoialign<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); +} + +void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t dtype, + const void *grads, const void *boxes, + void *grads_image, const int boxes_num, + const int hi, const int wi, const int c, + const int no, const int ho, const int wo, + const float spatial_scale, const int sampling_ratio, + const bool aligned) { + backward::MLUUnion1KernelRoiAlignBackward<<>>( + grads, boxes, grads_image, dtype, boxes_num, hi, wi, c, no, ho, wo, + spatial_scale, sampling_ratio, aligned); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp new file mode 100644 index 0000000000..1c23db2bce --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp @@ -0,0 +1,171 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_mlu_helper.hpp" + +void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const void *input, const void *rois, const int channels, + const bool aligned, const int pooled_height, + const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, + void *output); + +void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t dtype, + const void *grads, const void *boxes, + void *grads_image, const int boxes_num, + const int hi, const int wi, const int c, + const int no, const int ho, const int wo, + const float spatial_scale, const int sampling_ratio, + const bool aligned); + +void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + // params check + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "input type should be Float or Half, got ", input.scalar_type()); + TORCH_CHECK(rois.scalar_type() == input.scalar_type(), + "rois should have the same type as input"); + TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ", + input.dim(), "D"); + TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), + "D"); + TORCH_CHECK(pool_mode == 1, "pool_mode only suppurts 'avg' currently"); + + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); + auto input_tensor = + torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + if (output.numel() == 0) { + output = at::zeros({num_rois, channels, aligned_height, aligned_width}, + input.options()); + return; + } + + at::Tensor output_tmp = + at::empty({num_rois, channels, aligned_height, aligned_width}, + input.options(), memory_format); + + // get tensor impl + auto self_impl = torch_mlu::getMluTensorImpl(input_tensor); + auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto output_impl = torch_mlu::getMluTensorImpl(output_tmp); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get the mlu ptr + auto self_ptr = self_impl->cnnlMalloc(); + auto rois_ptr = rois_impl->cnnlMalloc(); + auto output_ptr = output_impl->cnnlMalloc(); + + cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; + cnrtDim3_t k_dim; + k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + k_dim.z = 1; + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype()); + + KernelRoiAlign(k_dim, k_type, queue, data_type, self_ptr, rois_ptr, channels, + aligned, aligned_height, aligned_width, height, width, + sampling_ratio, spatial_scale, num_rois, output_ptr); + + output.copy_(output_tmp); +} + +static int nearestPower2(int x) { + x--; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + x++; + return x; +} + +void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois, + Tensor argmax_y, Tensor argmax_x, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned) { + // params check + TORCH_CHECK( + grad.scalar_type() == at::kFloat || grad.scalar_type() == at::kHalf, + "grad type should be Float or Half, got ", grad.scalar_type()); + TORCH_CHECK(rois.scalar_type() == grad.scalar_type(), + "rois should have the same type as grad"); + TORCH_CHECK(grad.dim() == 4, "grad should be a 4d tensor, got ", grad.dim(), + "D"); + TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(), + "D"); + TORCH_CHECK(pool_mode == 1, "pool_mode only suppurts 'avg' currently"); + + int batch_size = grad_input.size(0); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(grad.dim()); + auto grad_ = torch_mlu::cnnl::ops::cnnl_contiguous(grad, memory_format); + auto grad_input_ = at::empty({batch_size, channels, height, width}, + grad.options(), memory_format) + .zero_(); + + int boxes_num = rois.size(0); + int hi = grad.size(2); + int wi = grad.size(3); + int c = grad.size(1); + + int no = grad_input.size(0); + int ho = grad_input.size(2); + int wo = grad_input.size(3); + + // get tensor impl + auto grad_impl = torch_mlu::getMluTensorImpl(grad_); + auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_); + auto rois_impl = torch_mlu::getMluTensorImpl(rois); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get the mlu ptr + auto grad_ptr = grad_impl->cnnlMalloc(); + auto rois_ptr = rois_impl->cnnlMalloc(); + auto grad_input_ptr = grad_input_impl->cnnlMalloc(); + + cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; + int need_core = nearestPower2(boxes_num); + int union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + uint32_t dim_x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + uint32_t dim_y = (need_core - 1) / dim_x + 1; + dim_y = (dim_y > union_number) ? union_number : dim_y; + cnrtDim3_t k_dim = {dim_x, dim_y, 1}; + cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad.dtype()); + + KernelRoiAlignBackward(k_dim, k_type, queue, k_dtype, grad_ptr, rois_ptr, + grad_input_ptr, boxes_num, hi, wi, c, no, ho, wo, + spatial_scale, sampling_ratio, aligned); + grad_input.copy_(grad_input_); +} diff --git a/mmcv/ops/csrc/pytorch/roi_align.cpp b/mmcv/ops/csrc/pytorch/roi_align.cpp index b44a742ceb..5c3a1dd30d 100644 --- a/mmcv/ops/csrc/pytorch/roi_align.cpp +++ b/mmcv/ops/csrc/pytorch/roi_align.cpp @@ -36,6 +36,40 @@ void roi_align_backward_cuda(Tensor grad_output, Tensor rois, Tensor argmax_y, } #endif +#ifdef MMCV_WITH_MLU +void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +void ROIAlignBackwardMLUKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax_y, Tensor argmax_x, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned); + +void roi_align_forward_mlu(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, bool aligned) { + ROIAlignForwardMLUKernelLauncher(input, rois, output, argmax_y, argmax_x, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward_mlu(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + ROIAlignBackwardMLUKernelLauncher( + grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height, + aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned); +} +#endif + void ROIAlignForwardCPULauncher(Tensor input, Tensor rois, Tensor output, Tensor argmax_y, Tensor argmax_x, int aligned_height, int aligned_width, @@ -85,6 +119,18 @@ void roi_align_forward(Tensor input, Tensor rois, Tensor output, sampling_ratio, pool_mode, aligned); #else AT_ERROR("RoIAlign is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (input.device().type() == at::kMLU) { + CHECK_MLU(input); + CHECK_MLU(rois); + CHECK_MLU(output); + CHECK_MLU(argmax_y); + CHECK_MLU(argmax_x); + + roi_align_forward_mlu(input, rois, output, argmax_y, argmax_x, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); #endif } else { CHECK_CPU_INPUT(input); @@ -115,6 +161,18 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, sampling_ratio, pool_mode, aligned); #else AT_ERROR("RoIAlign is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (grad_output.device().type() == at::kMLU) { + CHECK_MLU(grad_output); + CHECK_MLU(rois); + CHECK_MLU(argmax_y); + CHECK_MLU(argmax_x); + CHECK_MLU(grad_input); + + roi_align_backward_mlu(grad_output, rois, argmax_y, argmax_x, grad_input, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); #endif } else { CHECK_CPU_INPUT(grad_output); diff --git a/tests/test_ops/test_roi_align.py b/tests/test_ops/test_roi_align.py index db7c037401..9b0c94d694 100644 --- a/tests/test_ops/test_roi_align.py +++ b/tests/test_ops/test_roi_align.py @@ -2,6 +2,8 @@ import pytest import torch +from mmcv.utils import is_cuda, is_mlu + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -10,6 +12,7 @@ _USING_PARROTS = False # yapf:disable + inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]), ([[[[1., 2.], [3., 4.]], @@ -38,8 +41,6 @@ def _test_roialign_gradcheck(device, dtype): - if not torch.cuda.is_available() and device == 'cuda': - pytest.skip('test requires GPU') try: from mmcv.ops import RoIAlign except ModuleNotFoundError: @@ -64,8 +65,6 @@ def _test_roialign_gradcheck(device, dtype): def _test_roialign_allclose(device, dtype): - if not torch.cuda.is_available() and device == 'cuda': - pytest.skip('test requires GPU') try: from mmcv.ops import roi_align except ModuleNotFoundError: @@ -74,7 +73,6 @@ def _test_roialign_allclose(device, dtype): pool_w = 2 spatial_scale = 1.0 sampling_ratio = 2 - for case, output in zip(inputs, outputs): np_input = np.array(case[0]) np_rois = np.array(case[1]) @@ -94,8 +92,25 @@ def _test_roialign_allclose(device, dtype): x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3) -@pytest.mark.parametrize('device', ['cuda', 'cpu']) -@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half]) +@pytest.mark.parametrize('device', [ + 'cpu', + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif(not is_mlu(), reason='requires MLU support')) +]) +@pytest.mark.parametrize('dtype', [ + torch.float, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + is_mlu(), + reason='MLU does not support for 64-bit floating point')), + torch.half +]) def test_roialign(device, dtype): # check double only if dtype is torch.double: From 693de239bf12aeb155f1c8c16e4435b781defffb Mon Sep 17 00:00:00 2001 From: shlrao Date: Mon, 22 Nov 2021 14:53:01 +0800 Subject: [PATCH 16/30] [Feature] Support NMS with cambricon MLU backend (#1467) --- mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu | 649 ++++++++++++++++++++ mmcv/ops/csrc/common/pytorch_cpp_helper.hpp | 3 + mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp | 96 +++ mmcv/ops/csrc/pytorch/nms.cpp | 15 + tests/test_ops/test_nms.py | 18 +- 5 files changed, 777 insertions(+), 4 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp diff --git a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu new file mode 100644 index 0000000000..1095da870c --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu @@ -0,0 +1,649 @@ +/************************************************************************* + * Copyright (C) 2021 by Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" + +#define NMS_SIZE (64) +#define COORD_DIM (4) +#define MEMORY_CORE (0x80) +#define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score + +#define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024) +#define SIZE_SRAM_BUF (MAX_SRAM_SIZE) + +__nram__ int8_t nram_buffer[SIZE_NRAM_BUF]; +__mlu_shared__ int8_t sram_buffer[SIZE_SRAM_BUF]; + +__mlu_func__ void pvLock() { +#if __BANG_ARCH__ == 270 + if (coreId != MEMORY_CORE) { + __bang_lock(0, 0); + } +#endif +} + +__mlu_func__ void pvUnlock() { +#if __BANG_ARCH__ == 270 + if (coreId != MEMORY_CORE) { + __bang_unlock(0, 0); + } +#endif +} + +enum Addr { SRAM, GDRAM }; + +template +__mlu_func__ void nms_detection( + uint32_t *output_box_num, const int output_mode, const int input_layout, + OUT_DT *output_data, const Addr dst, IN_DT *input_data_score, + const IN_DT *input_data_box, const Addr src, IN_DT *buffer, + const int buffer_size, IN_DT *sram, const int core_limit, + const int input_box_num, const int input_stride, const int output_stride, + const int keepNum, const float thresh_iou, const float thresh_score, + const float offset, const int algo) { + // global value, it is stored in sram with a offset from the begin. + const int flag_offset_size = 28; + int32_t *loop_end_flag = (int32_t *)(sram + flag_offset_size); + loop_end_flag[0] = 0; + // score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2 + const int nms_buffer_count1 = 9; + // temp nram buffer to store selected target. + const int nram_save_limit_count = 256; + float div_thresh_iou = 1.0 / thresh_iou; + + // input data ptr + IN_DT *input_score_ptr; + const IN_DT *input_x1_ptr; + const IN_DT *input_y1_ptr; + const IN_DT *input_x2_ptr; + const IN_DT *input_y2_ptr; + input_score_ptr = input_data_score; + input_x1_ptr = input_data_box; + if (input_layout == 0) { + // [boxes_num, 4] + input_y1_ptr = input_x1_ptr + 1; + input_x2_ptr = input_x1_ptr + 2; + input_y2_ptr = input_x1_ptr + 3; + } else if (input_layout == 1) { + // [4, boxes_num] + input_y1_ptr = input_x1_ptr + input_stride; + input_x2_ptr = input_y1_ptr + input_stride; + input_y2_ptr = input_x2_ptr + input_stride; + } + + // nram data ptr + IN_DT *x1; + IN_DT *y1; + IN_DT *x2; + IN_DT *y2; + IN_DT *score; + IN_DT *inter_x1; + IN_DT *inter_y1; + IN_DT *inter_x2; + IN_DT *inter_y2; + IN_DT *max_box; // the max score, x1, y1, x2, y2 + IN_DT *x1_mask; + IN_DT *y1_mask; + IN_DT *x2_mask; + IN_DT *y2_mask; + OUT_DT *nram_save; + + int limit = 0; // find limit when GDRAM or SRAM + int len_core = 0; // the length deal by every core + int max_seg_pad = 0; // the max length every repeat + int repeat = 0; + int remain = 0; + int remain_pad = 0; + int input_offset = 0; // offset of input_data for current core + int nram_save_count = 0; + // mask for collect x1, y1, x2, y2. each mask has 128 elements + const int mask_size = 128; + const int total_mask_size = 512; + + if (output_mode == 0) { + limit = (buffer_size - 128 /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * sizeof(OUT_DT) - + total_mask_size * sizeof(IN_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } else { + limit = (buffer_size - 128 /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * INFO_NUM * sizeof(OUT_DT) - + total_mask_size * sizeof(IN_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } + + if (core_limit == 1) { + len_core = input_box_num; + input_offset = 0; + } else { + int avg_core = input_box_num / core_limit; + int rem = input_box_num % core_limit; + len_core = avg_core + (taskId < rem ? 1 : 0); + input_offset = avg_core * taskId + (taskId <= rem ? taskId : rem); + } + max_seg_pad = PAD_DOWN(limit, NMS_SIZE); + repeat = len_core / max_seg_pad; + remain = len_core % max_seg_pad; + remain_pad = PAD_UP(remain, NMS_SIZE); + + // if datatype is half, we should convert it to float when compute the IoU + int max_seg_iou_compute = + PAD_DOWN(max_seg_pad / (sizeof(float) / sizeof(IN_DT)), NMS_SIZE); + int repeat_iou_compute = len_core / max_seg_iou_compute; + int remain_iou_compute = len_core % max_seg_iou_compute; + int remain_pad_iou_compute = PAD_UP(remain_iou_compute, NMS_SIZE); + // initial the address point + score = buffer; + x1 = score + max_seg_pad; + y1 = x1 + max_seg_pad; + x2 = y1 + max_seg_pad; + y2 = x2 + max_seg_pad; + inter_x1 = y2 + max_seg_pad; + inter_y1 = inter_x1 + max_seg_pad; + inter_x2 = inter_y1 + max_seg_pad; + inter_y2 = inter_x2 + max_seg_pad; + x1_mask = inter_y2 + max_seg_pad; + y1_mask = x1_mask + mask_size; + x2_mask = y1_mask + mask_size; + y2_mask = x2_mask + mask_size; + max_box = y2_mask + mask_size; // the max score, x1, y1, x2, y2 + // offset two line from max_box + nram_save = (OUT_DT *)((char *)max_box + NFU_ALIGN_SIZE); + + // set mask for __bang_collect instruction + if (input_layout == 0) { + __nramset((IN_DT *)x1_mask, total_mask_size, (IN_DT)0); + for (int idx = 0; idx < mask_size; idx++) { + int index = (idx % COORD_DIM) * mask_size + idx; + x1_mask[index] = (IN_DT)1.0; + } + } + + for (int keep = 0; keep < keepNum; keep++) { // loop until the max_score <= 0 + if (core_limit != 1) { + __sync_cluster(); // sync before current loop + } + + /******find max start******/ + int max_index = 0; // the max score index + int global_max_index = 0; // for U1 + float max_area = 0; // the max score area + max_box[0] = 0; // init 0 + + for (int i = 0; i <= repeat; i++) { + if (i == repeat && remain == 0) { + break; + } + int seg_len = 0; // the length every nms compute + int cpy_len = 0; // the length every nms memcpy + i == repeat ? seg_len = remain_pad : seg_len = max_seg_pad; + // check seg_len exceeds the limit of fp16 or not. 65536 is the largest + // num that half data type could express. + if (sizeof(IN_DT) == sizeof(half) && seg_len > 65536) { + // seg length exceeds the max num for fp16 datatype! + return; + } + i == repeat ? cpy_len = remain : cpy_len = max_seg_pad; + /******nms load start******/ + mluMemcpyDirection_t load_dir = SRAM2NRAM; + if (src == SRAM) { + load_dir = SRAM2NRAM; + } else { + load_dir = GDRAM2NRAM; + } + __nramset(score, seg_len, (IN_DT)0); + __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + + /******nms load end******/ + + __bang_max(inter_x1, score, seg_len); + if (inter_x1[0] > max_box[0]) { + max_box[0] = inter_x1[0]; + + if (sizeof(IN_DT) == sizeof(half)) { + max_index = ((uint16_t *)inter_x1)[1] + input_offset + + i * max_seg_pad; // offset start from head of input_data + } else if (sizeof(IN_DT) == sizeof(float)) { + max_index = ((uint32_t *)inter_x1)[1] + input_offset + + i * max_seg_pad; // offset start from head of input_data + } + } + } // for repeat + + int stride = 1; + if (input_layout == 0) { + stride = input_stride; + } else if (input_layout == 1) { + stride = 1; + } + + if (core_limit == 1) { + max_box[1] = input_x1_ptr[max_index * stride]; + max_box[2] = input_y1_ptr[max_index * stride]; + max_box[3] = input_x2_ptr[max_index * stride]; + max_box[4] = input_y2_ptr[max_index * stride]; + if (algo == 0 || offset == 0.0) { + max_area = ((float)max_box[3] - (float)max_box[1]) * + ((float)max_box[4] - (float)max_box[2]); + } else { + max_area = ((float)max_box[3] - (float)max_box[1] + offset) * + ((float)max_box[4] - (float)max_box[2] + offset); + } + input_score_ptr[max_index] = 0; + global_max_index = max_index; + ((uint32_t *)(max_box + INFO_NUM))[0] = max_index; + } else if (core_limit == 4) { + // find the max with sram + // the max box's x1, y1, x2, y2 on every core + if (coreId != MEMORY_CORE) { + max_box[1] = input_x1_ptr[max_index * stride]; + max_box[2] = input_y1_ptr[max_index * stride]; + max_box[3] = input_x2_ptr[max_index * stride]; + max_box[4] = input_y2_ptr[max_index * stride]; + } + ((uint32_t *)(max_box + INFO_NUM))[0] = max_index; + // copy every core's box info to sram, form: score---x1---y1---x2---y2--- + for (int i = 0; i < INFO_NUM; i++) { + __memcpy(sram + i * core_limit + taskId, max_box + i, 1 * sizeof(IN_DT), + NRAM2SRAM); + } + // copy every core's max_index to sram, use 2 half to store max_index + __memcpy(sram + INFO_NUM * core_limit + taskId * 2, max_box + INFO_NUM, + sizeof(uint32_t), + NRAM2SRAM); // int32_t datatype + __sync_cluster(); + + // copy score from sram to nram and find the max + __nramset(inter_x1, NMS_SIZE, (IN_DT)0); + __memcpy(inter_x1, sram, core_limit * sizeof(IN_DT), SRAM2NRAM); + __bang_max(max_box, inter_x1, NMS_SIZE); + int max_core = 0; + if (sizeof(IN_DT) == sizeof(half)) { + max_core = ((uint16_t *)max_box)[1]; + } else if (sizeof(IN_DT) == sizeof(float)) { + max_core = ((uint32_t *)max_box)[1]; + } + + // copy the max box from SRAM to NRAM + __memcpy(max_box + 1, sram + 1 * core_limit + max_core, 1 * sizeof(IN_DT), + SRAM2NRAM); // x1 + __memcpy(max_box + 2, sram + 2 * core_limit + max_core, 1 * sizeof(IN_DT), + SRAM2NRAM); // y1 + __memcpy(max_box + 3, sram + 3 * core_limit + max_core, 1 * sizeof(IN_DT), + SRAM2NRAM); // x2 + __memcpy(max_box + 4, sram + 4 * core_limit + max_core, 1 * sizeof(IN_DT), + SRAM2NRAM); // y2 + __memcpy(max_box + 5, sram + 5 * core_limit + 2 * max_core, + sizeof(uint32_t), SRAM2NRAM); + if (algo == 0 || offset == 0.0) { + max_area = ((float)max_box[3] - (float)max_box[1]) * + ((float)max_box[4] - (float)max_box[2]); + } else { + max_area = ((float)max_box[3] - (float)max_box[1] + offset) * + ((float)max_box[4] - (float)max_box[2] + offset); + } + global_max_index = ((uint32_t *)(max_box + INFO_NUM))[0]; + input_score_ptr[global_max_index] = 0; + } + // by now, we get: max_score|max_index|max_box|max_area + /******find max end******/ + + /******nms store start******/ + // store to nram + if (float(max_box[0]) > thresh_score) { + OUT_DT *save_ptr; + int save_offset = 0; + int save_str_num = 0; + save_ptr = nram_save; + save_offset = nram_save_count; + save_str_num = nram_save_limit_count; + if (coreId == 0) { + if (output_mode == 0) { // index1, index2, ... + __memcpy(save_ptr + save_offset, (uint32_t *)(max_box + INFO_NUM), + 1 * sizeof(uint32_t), NRAM2NRAM, 1 * sizeof(uint32_t), + 1 * sizeof(uint32_t), 0); + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + __memcpy(save_ptr + save_offset * INFO_NUM, max_box, + INFO_NUM * sizeof(IN_DT), NRAM2NRAM, + INFO_NUM * sizeof(IN_DT), INFO_NUM * sizeof(IN_DT), 0); + } else if (output_mode == 2) { // score---, x1---, y1---, x2---, y2--- + __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), + NRAM2NRAM, save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), + 4); + } + } + nram_save_count++; + (*output_box_num)++; + } + + // store to sram/gdram + if (*output_box_num != 0) { + mluMemcpyDirection_t store_dir = NRAM2GDRAM; + if (dst == SRAM) { + store_dir = NRAM2SRAM; + } else { // dst == GDRAM + store_dir = NRAM2GDRAM; + } + if ((nram_save_count == nram_save_limit_count) || + (float(max_box[0]) <= thresh_score) || keep == keepNum - 1) { + if (nram_save_count != 0) { + if (coreId == 0) { + if (output_mode == 0) { // index1, index2, ... + pvLock(); + __memcpy(output_data, nram_save, + nram_save_count * sizeof(uint32_t), store_dir); + pvUnlock(); + output_data += nram_save_count; + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + pvLock(); + __memcpy(output_data, nram_save, + nram_save_count * INFO_NUM * sizeof(IN_DT), store_dir); + pvUnlock(); + output_data += nram_save_count * INFO_NUM; + } else if (output_mode == + 2) { // score---, x1---, y1---, x2---, y2--- + pvLock(); + __memcpy(output_data, nram_save, nram_save_count * sizeof(IN_DT), + store_dir, output_stride * sizeof(IN_DT), + nram_save_limit_count * sizeof(IN_DT), 4); + pvUnlock(); + output_data += nram_save_count; + } + nram_save_count = 0; + } + } + } // if move data nram->sram/gdram + } // if dst + + // if the max score <= 0, end + if (core_limit == 1) { + if (float(max_box[0]) <= thresh_score) { + break; + } + } else { + if (float(max_box[0]) <= thresh_score) { + if (coreId == 0) { + loop_end_flag[0] = 1; + } + } + __sync_cluster(); + if (loop_end_flag[0] == 1) { + break; + } + } + /******nms store end******/ + + // To solve half data accuracy, we convert half to float to calculate IoU. + for (int i = 0; i <= repeat_iou_compute; i++) { + if (i == repeat_iou_compute && remain_iou_compute == 0) { + break; + } + int seg_len = 0; // the length every nms compute + int cpy_len = 0; // the length every nms memcpy + i == repeat_iou_compute ? seg_len = remain_pad_iou_compute + : seg_len = max_seg_iou_compute; + i == repeat_iou_compute ? cpy_len = remain_iou_compute + : cpy_len = max_seg_iou_compute; + + /******nms load start******/ + mluMemcpyDirection_t load_dir = SRAM2NRAM; + if (src == SRAM) { + load_dir = SRAM2NRAM; + } else { + load_dir = GDRAM2NRAM; + } + + __nramset((float *)score, seg_len, 0.0f); + int dt_offset = 0; + if (sizeof(IN_DT) == sizeof(float)) { + __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + dt_offset = 0; + } else if (sizeof(IN_DT) == sizeof(half)) { + __nramset(x1, seg_len, half(0)); + __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + __bang_half2float((float *)score, (half *)x1, seg_len); + dt_offset = max_seg_iou_compute; + } + + if (input_layout == 0) { + // the following number 4 means x1, y1, x2, y2 + __memcpy( + inter_x1, + input_x1_ptr + (input_offset + i * max_seg_iou_compute) * COORD_DIM, + cpy_len * COORD_DIM * sizeof(IN_DT), load_dir, + cpy_len * COORD_DIM * sizeof(IN_DT), + cpy_len * COORD_DIM * sizeof(IN_DT), 0); + // here use collect instruction to transpose the [n, 4] shape into [4, + // n] shape to avoid + // discrete memory accessing. + for (int c_i = 0; c_i < COORD_DIM * seg_len / mask_size; c_i++) { + // the following number 32 means 32 elements will be selected out by + // once operation + __bang_collect(x1 + dt_offset + c_i * 32, inter_x1 + c_i * mask_size, + x1_mask, mask_size); + __bang_collect(y1 + dt_offset + c_i * 32, inter_x1 + c_i * mask_size, + y1_mask, mask_size); + __bang_collect(x2 + dt_offset + c_i * 32, inter_x1 + c_i * mask_size, + x2_mask, mask_size); + __bang_collect(y2 + dt_offset + c_i * 32, inter_x1 + c_i * mask_size, + y2_mask, mask_size); + } + } else if (input_layout == 1) { + __memcpy(x1 + dt_offset, + input_x1_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + __memcpy(y1 + dt_offset, + input_y1_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + __memcpy(x2 + dt_offset, + input_x2_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + __memcpy(y2 + dt_offset, + input_y2_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + } + /******nms load end******/ + + /******nms compute start******/ + if (sizeof(IN_DT) == sizeof(half)) { + __bang_half2float((float *)x1, (half *)x1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y1, (half *)y1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)x2, (half *)x2 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y2, (half *)y2 + max_seg_iou_compute, + seg_len); + } + // 1、 compute IOU + // get the area_I + __nramset((float *)inter_y1, seg_len, float(max_box[1])); // max_x1 + __bang_maxequal((float *)inter_x1, (float *)x1, (float *)inter_y1, + seg_len); // inter_x1 + __nramset((float *)inter_y2, seg_len, float(max_box[3])); // max_x2 + __bang_minequal((float *)inter_x2, (float *)x2, (float *)inter_y2, + seg_len); // inter_x2 + __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_x1, (float *)inter_x1, offset, seg_len); + } + __bang_active_relu((float *)inter_x1, (float *)inter_x1, + seg_len); // inter_w + __nramset((float *)inter_x2, seg_len, float(max_box[2])); // max_y1 + __bang_maxequal((float *)inter_y1, (float *)y1, (float *)inter_x2, + seg_len); // inter_y1 + __nramset((float *)inter_x2, seg_len, float(max_box[4])); // max_y2 + __bang_minequal((float *)inter_y2, (float *)y2, (float *)inter_x2, + seg_len); // inter_y2 + __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + } + __bang_active_relu((float *)inter_y1, (float *)inter_y1, + seg_len); // inter_h + __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1, + seg_len); // area_I + // get the area of input_box: area = (x2 - x1) * (y2 - y1); + __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len); + __bang_sub((float *)inter_y2, (float *)y2, (float *)y1, seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + __bang_add_const((float *)inter_y2, (float *)inter_y2, offset, seg_len); + } + __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2, + seg_len); // area + // get the area_U: area + max_area - area_I + __bang_add_const((float *)inter_x2, (float *)inter_x2, float(max_area), + seg_len); + __bang_sub((float *)inter_x2, (float *)inter_x2, (float *)inter_x1, + seg_len); // area_U + // 2、 select the box + // if IOU greater than thres, set the score to zero, abort it: area_U > + // area_I * (1 / thresh)? + if (thresh_iou > 0.0) { + __bang_mul_const((float *)inter_x1, (float *)inter_x1, div_thresh_iou, + seg_len); + } else { + __bang_mul_const((float *)inter_x2, (float *)inter_x2, thresh_iou, + seg_len); + } + __bang_ge((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len); + /******nms compute end******/ + + // update the score + mluMemcpyDirection_t update_dir = NRAM2SRAM; + if (dst == SRAM) { + update_dir = NRAM2SRAM; + } else { + update_dir = NRAM2GDRAM; + } + if (sizeof(IN_DT) == sizeof(half)) { + __bang_float2half_rd((half *)score, (float *)score, seg_len); + } + pvLock(); + __memcpy(input_score_ptr + input_offset + i * max_seg_iou_compute, score, + cpy_len * sizeof(IN_DT), update_dir, cpy_len * sizeof(IN_DT), + cpy_len * sizeof(IN_DT), 0); + pvUnlock(); + } // for repeat + } // for keepNum +} + +__mlu_global__ void MLUKernelNMS( + const void *input_boxes, const void *input_confidence, + const int input_num_boxes, const int input_stride, + const int max_output_size, const float iou_threshold, + const float confidence_threshold, const int mode, const int input_layout, + void *workspace, void *result_num, void *output, + const cnrtDataType_t data_type_input, const float offset, const int algo) { + if (data_type_input == CNRT_FLOAT16) { + __memcpy(workspace, input_confidence, input_num_boxes * sizeof(half), + GDRAM2GDRAM); + } else if (data_type_input == CNRT_FLOAT32) { + __memcpy(workspace, input_confidence, input_num_boxes * sizeof(float), + GDRAM2GDRAM); + } else { + } + + int output_stride = max_output_size; + uint32_t result_box_num = 0; + if (mode == 0) { + uint32_t *out_data = (uint32_t *)output; + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *boxes_data = (half *)input_boxes; + half *confi_data = (half *)workspace; + half *buffer = (half *)nram_buffer; + half *sram = (half *)sram_buffer; + + nms_detection(&result_box_num, mode, input_layout, out_data, GDRAM, + confi_data, boxes_data, GDRAM, buffer, SIZE_NRAM_BUF, + sram, taskDim, input_num_boxes, input_stride, + output_stride, max_output_size, iou_threshold, + confidence_threshold, offset, algo); + ((uint32_t *)result_num)[0] = result_box_num; + }; break; + case CNRT_FLOAT32: { + float *boxes_data = (float *)input_boxes; + float *confi_data = (float *)workspace; + float *buffer = (float *)nram_buffer; + float *sram = (float *)sram_buffer; + + nms_detection(&result_box_num, mode, input_layout, out_data, GDRAM, + confi_data, boxes_data, GDRAM, buffer, SIZE_NRAM_BUF, + sram, taskDim, input_num_boxes, input_stride, + output_stride, max_output_size, iou_threshold, + confidence_threshold, offset, algo); + ((uint32_t *)result_num)[0] = result_box_num; + }; break; + } + } else { + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *boxes_data = (half *)input_boxes; + half *confi_data = (half *)workspace; + half *out_data = (half *)output; + half *buffer = (half *)nram_buffer; + half *sram = (half *)sram_buffer; + + nms_detection(&result_box_num, mode, input_layout, out_data, GDRAM, + confi_data, boxes_data, GDRAM, buffer, SIZE_NRAM_BUF, + sram, taskDim, input_num_boxes, input_stride, + output_stride, max_output_size, iou_threshold, + confidence_threshold, offset, algo); + ((uint32_t *)result_num)[0] = result_box_num; + }; break; + case CNRT_FLOAT32: { + float *boxes_data = (float *)input_boxes; + float *confi_data = (float *)workspace; + float *out_data = (float *)output; + float *buffer = (float *)nram_buffer; + float *sram = (float *)sram_buffer; + + nms_detection(&result_box_num, mode, input_layout, out_data, GDRAM, + confi_data, boxes_data, GDRAM, buffer, SIZE_NRAM_BUF, + sram, taskDim, input_num_boxes, input_stride, + output_stride, max_output_size, iou_threshold, + confidence_threshold, offset, algo); + ((uint32_t *)result_num)[0] = result_box_num; + }; break; + } + } +} + +void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t data_type_input, const void *boxes_ptr, + const void *scores_ptr, const int input_num_boxes, + const int input_stride, const int max_output_boxes, + const float iou_threshold, const float offset, + void *workspace_ptr, void *output_size_ptr, void *output_ptr) { + MLUKernelNMS<<>>( + boxes_ptr, scores_ptr, input_num_boxes, input_stride, max_output_boxes, + iou_threshold, /*confidence_threshold=*/0.0, /*output_mode=*/0, + /*input_layout=*/0, workspace_ptr, output_size_ptr, output_ptr, + data_type_input, offset, /*algo=*/1); +} diff --git a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp index 4f198ac37b..15c5333712 100644 --- a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp @@ -19,6 +19,9 @@ using namespace at; #define CHECK_CUDA_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) +#define CHECK_MLU_INPUT(x) \ + CHECK_MLU(x); \ + CHECK_CONTIGUOUS(x) #define CHECK_CPU_INPUT(x) \ CHECK_CPU(x); \ CHECK_CONTIGUOUS(x) diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp new file mode 100644 index 0000000000..af193fce33 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp @@ -0,0 +1,96 @@ +/************************************************************************* + * Copyright (C) 2021 by Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "pytorch_mlu_helper.hpp" + +void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t data_type_input, const void *boxes_ptr, + const void *scores_ptr, const int input_num_boxes, + const int input_stride, const int max_output_boxes, + const float iou_threshold, const float offset, + void *workspace_ptr, void *output_size_ptr, void *output_ptr); + +Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, + int offset) { + // dimension parameters check + TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", + boxes.dim(), "D"); + TORCH_CHECK(boxes.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + boxes.size(1)); + TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", + scores.dim(), "D"); + + // data type check + TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(), + "boxes should have the same type as scores"); + TORCH_CHECK( + boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf, + "data type of boxes should be Float or Half, got ", boxes.scalar_type()); + + if (boxes.numel() == 0) { + return at::empty({0}, boxes.options().dtype(at::kLong)); + } + + int input_num_boxes = boxes.size(0); + int input_stride = boxes.size(1); + int max_output_boxes = boxes.size(0); + cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; + int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + uint32_t dim_x = core_dim; + cnrtDim3_t k_dim = {dim_x, 1, 1}; + cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); + + auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong)); + auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); + + // workspace + size_t space_size = 0; + if (boxes.scalar_type() == at::kHalf) { + space_size = input_num_boxes * sizeof(int16_t); + } else { + space_size = input_num_boxes * sizeof(float); + } + auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte)); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + auto boxes_impl = torch_mlu::getMluTensorImpl(boxes); + auto boxes_ptr = boxes_impl->cnnlMalloc(); + auto scores_impl = torch_mlu::getMluTensorImpl(scores); + auto scores_ptr = scores_impl->cnnlMalloc(); + auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); + auto workspace_ptr = workspace_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); + auto output_size_ptr = output_size_impl->cnnlMalloc(); + + switch (k_type) { + default: { + TORCH_CHECK(false, "[nms_mlu]:Failed to choose kernel to launch"); + } + case CNRT_FUNC_TYPE_BLOCK: + case CNRT_FUNC_TYPE_UNION1: { + CNLOG(INFO) << "Launch Kernel MLUUnion1 or Block NMS<<>>"; + KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, + input_num_boxes, input_stride, max_output_boxes, iou_threshold, + offset, workspace_ptr, output_size_ptr, output_ptr); + }; break; + } + + int output_num = *static_cast(output_size.cpu().data_ptr()); + return output.slice(0, 0, output_num); +} diff --git a/mmcv/ops/csrc/pytorch/nms.cpp b/mmcv/ops/csrc/pytorch/nms.cpp index e88208dc9f..8d6844e9ff 100644 --- a/mmcv/ops/csrc/pytorch/nms.cpp +++ b/mmcv/ops/csrc/pytorch/nms.cpp @@ -10,6 +10,15 @@ Tensor nms_cuda(Tensor boxes, Tensor scores, float iou_threshold, int offset) { } #endif +#ifdef MMCV_WITH_MLU +Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, + int offset); + +Tensor nms_mlu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { + return NMSMLUKernelLauncher(boxes, scores, iou_threshold, offset); +} +#endif + Tensor nms_cpu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { if (boxes.numel() == 0) { return at::empty({0}, boxes.options().dtype(at::kLong)); @@ -69,6 +78,12 @@ Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset) { return nms_cuda(boxes, scores, iou_threshold, offset); #else AT_ERROR("nms is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (boxes.device().type() == at::kMLU) { + CHECK_MLU_INPUT(boxes); + CHECK_MLU_INPUT(scores); + return nms_mlu(boxes, scores, iou_threshold, offset); #endif } else { CHECK_CPU_INPUT(boxes); diff --git a/tests/test_ops/test_nms.py b/tests/test_ops/test_nms.py index 3c59204b1b..4831f6f644 100644 --- a/tests/test_ops/test_nms.py +++ b/tests/test_ops/test_nms.py @@ -2,12 +2,22 @@ import pytest import torch +from mmcv.utils import is_cuda, is_mlu + class Testnms(object): - def test_nms_allclose(self): - if not torch.cuda.is_available(): - return + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not is_mlu(), reason='requires MLU support')) + ]) + def test_nms_allclose(self, device): from mmcv.ops import nms np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], @@ -23,7 +33,7 @@ def test_nms_allclose(self): assert np.allclose(dets, np_dets) # test cpu assert np.allclose(inds, np_inds) # test cpu dets, inds = nms( - boxes.cuda(), scores.cuda(), iou_threshold=0.3, offset=0) + boxes.to(device), scores.to(device), iou_threshold=0.3, offset=0) assert np.allclose(dets.cpu().numpy(), np_dets) # test gpu assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu From e1c73552a4062205c81a4dcea255e584a07d6e30 Mon Sep 17 00:00:00 2001 From: zhouchenyang Date: Tue, 23 Nov 2021 10:28:07 +0800 Subject: [PATCH 17/30] [Feature] Support BBoxOverlaps with cambricon MLU backend (#1507) --- .../common/mlu/bbox_overlaps_mlu_kernel.mlu | 322 ++++++++++++++++++ mmcv/ops/csrc/pytorch/bbox_overlaps.cpp | 19 ++ .../csrc/pytorch/mlu/bbox_overlaps_mlu.cpp | 90 +++++ tests/test_ops/test_bbox.py | 50 ++- 4 files changed, 465 insertions(+), 16 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/bbox_overlaps_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp diff --git a/mmcv/ops/csrc/common/mlu/bbox_overlaps_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/bbox_overlaps_mlu_kernel.mlu new file mode 100644 index 0000000000..58e695a015 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/bbox_overlaps_mlu_kernel.mlu @@ -0,0 +1,322 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include + +#include "common_mlu_helper.hpp" + +#define COORD_NUM 4 + +__nram__ char nmem_buf[MAX_NRAM_SIZE]; + +template +__mlu_func__ void computeDiv(void *nram_dst, void *nram_src0, void *nram_src1, + void *nram_addition, const int32_t deal_num) { + __bang_active_reciphp((T *)nram_dst, (T *)nram_src1, deal_num); + __bang_mul((T *)nram_dst, (T *)nram_src0, (T *)nram_dst, deal_num); +} + +template <> +__mlu_func__ void computeDiv(void *nram_dst, void *nram_src0, + void *nram_src1, void *nram_addition, + const int32_t deal_num) { + __bang_half2float((float *)nram_addition, (half *)nram_src1, deal_num); + __bang_active_reciphp((float *)nram_addition, (float *)nram_addition, + deal_num); + __bang_float2half_rd((half *)nram_src1, (float *)nram_addition, deal_num); + __bang_mul((half *)nram_dst, (half *)nram_src0, (half *)nram_src1, deal_num); +} + +template +__mlu_func__ void bboxOverlapsWorkflow( + T *vec_b1_x1, T *vec_b1_y1, T *vec_b1_x2, T *vec_b1_y2, T *vec_b2_x1, + T *vec_b2_y1, T *vec_b2_x2, T *vec_b2_y2, T *vec_left, T *vec_right, + T *vec_top, T *vec_bottom, const T *bbox1, const T *bbox2, void *ious, + const int32_t offset, const int32_t mode, const int32_t batches_stride, + const int32_t num_bbox1, const int32_t num_bbox2, const bool aligned) { + int32_t task_batch_stride = (num_bbox1 + taskDim - 1) / taskDim; + int32_t batch_start = taskId * task_batch_stride; + int32_t batch_per_task = batch_start + task_batch_stride < num_bbox1 + ? task_batch_stride + : num_bbox1 - batch_start; + batch_per_task = batch_per_task > 0 ? batch_per_task : (0); + + if (aligned) { + int32_t num_loop_cpy = batch_per_task / batches_stride; + int32_t num_rem_cpy_batches = batch_per_task % batches_stride; + num_loop_cpy = num_rem_cpy_batches > 0 ? num_loop_cpy + 1 : num_loop_cpy; + for (int32_t i = 0; i < num_loop_cpy; i++) { + int32_t index = batch_start + i * batches_stride; + int32_t handle_batches = index + batches_stride > num_bbox1 + ? num_rem_cpy_batches + : batches_stride; + int32_t b1 = index; + int32_t b2 = index; + + int32_t base1 = b1 * COORD_NUM; + __memcpy(vec_b1_x1, &bbox1[base1], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b1_y1, &bbox1[base1 + 1], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b1_x2, &bbox1[base1 + 2], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b1_y2, &bbox1[base1 + 3], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + + int32_t base2 = b2 * COORD_NUM; + __memcpy(vec_b2_x1, &bbox2[base2], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_y1, &bbox2[base2 + 1], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_x2, &bbox2[base2 + 2], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_y2, &bbox2[base2 + 3], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + // get the width and height + __bang_maxequal(vec_left, vec_b1_x1, vec_b2_x1, batches_stride); + __bang_minequal(vec_right, vec_b1_x2, vec_b2_x2, batches_stride); + __bang_maxequal(vec_top, vec_b1_y1, vec_b2_y1, batches_stride); + __bang_minequal(vec_bottom, vec_b1_y2, vec_b2_y2, batches_stride); + + // right - left + offset ---> left + __bang_sub(vec_left, vec_right, vec_left, batches_stride); + __bang_add_const(vec_left, vec_left, (T)offset, batches_stride); + + // bottom - top + offset ---> right + __bang_sub(vec_right, vec_bottom, vec_top, batches_stride); + __bang_add_const(vec_right, vec_right, (T)offset, batches_stride); + + // zero vector ---> bottom + __nramset(vec_bottom, batches_stride, 0.f); + + // width --> vec_left + __bang_maxequal(vec_left, vec_bottom, vec_left, batches_stride); + T *width = vec_left; + // height --> vec_right + __bang_maxequal(vec_right, vec_bottom, vec_right, batches_stride); + T *height = vec_right; + + // get the b1_area + // (b1_x2 - b1_x1 + offset) ---> vec_top + __bang_sub(vec_top, vec_b1_x2, vec_b1_x1, batches_stride); + __bang_add_const(vec_top, vec_top, (T)offset, batches_stride); + + // (b1_y2 - b1_y1 + offset) ---> vec_bottom + __bang_sub(vec_bottom, vec_b1_y2, vec_b1_y1, batches_stride); + __bang_add_const(vec_bottom, vec_bottom, (T)offset, batches_stride); + + // b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset) + // ---> vec_top; + __bang_mul(vec_top, vec_top, vec_bottom, batches_stride); + T *b1_area = vec_top; + + // get the b2_area + // (b2_x2 - b2_x1 + offset) ---> b2_x1 + __bang_sub(vec_b2_x1, vec_b2_x2, vec_b2_x1, batches_stride); + __bang_add_const(vec_b2_x1, vec_b2_x1, (T)offset, batches_stride); + + // (b2_y2 - b2_y1 + offset) ---> b2_y1 + __bang_sub(vec_b2_y1, vec_b2_y2, vec_b2_y1, batches_stride); + __bang_add_const(vec_b2_y1, vec_b2_y1, (T)offset, batches_stride); + + // b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset) + // ---> b2_x1; + __bang_mul(vec_b2_x1, vec_b2_x1, vec_b2_y1, batches_stride); + T *b2_area = vec_b2_x1; + + // inter_s = width * height + __bang_mul(height, width, height, batches_stride); + T *inter_s = height; + + // offset vector ---> vec_b2_y1 + __nramset(vec_b2_y1, batches_stride, T(offset)); + T *vec_offset = vec_b2_y1; + + if (mode == 0) { + __bang_add(b1_area, b1_area, b2_area, batches_stride); + __bang_sub(b1_area, b1_area, inter_s, batches_stride); + __bang_maxequal(b1_area, vec_offset, b1_area, batches_stride); + } else { + __bang_maxequal(b1_area, vec_offset, b1_area, batches_stride); + } + T *base_s = b1_area; + + // ious = inter_s / base_s + computeDiv(width, inter_s, base_s, vec_b2_x2, batches_stride); + __memcpy((T *)ious + index, width, handle_batches * sizeof(T), + NRAM2GDRAM); + } + } else { + int32_t num_loop_cpy = num_bbox2 / batches_stride; + int32_t num_rem_cpy_batches = num_bbox2 % batches_stride; + num_loop_cpy = num_rem_cpy_batches > 0 ? num_loop_cpy + 1 : num_loop_cpy; + for (int32_t i = 0; i < batch_per_task; i++) { + int32_t index1 = batch_start + i; + int32_t b1 = index1; + int32_t base1 = b1 * COORD_NUM; + + // set bbox1 and bbox2 to nram + __nramset(vec_b1_x1, batches_stride, bbox1[base1]); + __nramset(vec_b1_y1, batches_stride, bbox1[base1 + 1]); + __nramset(vec_b1_x2, batches_stride, bbox1[base1 + 2]); + __nramset(vec_b1_y2, batches_stride, bbox1[base1 + 3]); + + for (int32_t j = 0; j < num_loop_cpy; j++) { + int32_t index2 = j * batches_stride; + int32_t handle_batches = index2 + batches_stride > num_bbox2 + ? num_rem_cpy_batches + : batches_stride; + int32_t b2 = index2; + int32_t base2 = b2 * COORD_NUM; + + // copy bbox2 to nram + __memcpy(vec_b2_x1, &bbox2[base2], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_y1, &bbox2[base2 + 1], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_x2, &bbox2[base2 + 2], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + __memcpy(vec_b2_y2, &bbox2[base2 + 3], sizeof(T), GDRAM2NRAM, sizeof(T), + COORD_NUM * sizeof(T), handle_batches - 1); + + // get the width and height + __bang_maxequal(vec_left, vec_b1_x1, vec_b2_x1, batches_stride); + __bang_minequal(vec_right, vec_b1_x2, vec_b2_x2, batches_stride); + __bang_maxequal(vec_top, vec_b1_y1, vec_b2_y1, batches_stride); + __bang_minequal(vec_bottom, vec_b1_y2, vec_b2_y2, batches_stride); + + // right - left + offset ---> left + __bang_sub(vec_left, vec_right, vec_left, batches_stride); + __bang_add_const(vec_left, vec_left, (T)offset, batches_stride); + // bottom - top + offset ---> right + __bang_sub(vec_right, vec_bottom, vec_top, batches_stride); + __bang_add_const(vec_right, vec_right, (T)offset, batches_stride); + + // zero vector ---> bottom + __nramset(vec_bottom, batches_stride, (T)0); + + // width --> vec_left + __bang_maxequal(vec_left, vec_bottom, vec_left, batches_stride); + T *width = vec_left; + // height --> vec_right + __bang_maxequal(vec_right, vec_bottom, vec_right, batches_stride); + T *height = vec_right; + + // get the b1_area + // (b1_x2 - b1_x1 + offset) ---> vec_top + __bang_sub(vec_top, vec_b1_x2, vec_b1_x1, batches_stride); + __bang_add_const(vec_top, vec_top, (T)offset, batches_stride); + // (b1_y2 - b1_y1 + offset) ---> vec_bottom + __bang_sub(vec_bottom, vec_b1_y2, vec_b1_y1, batches_stride); + __bang_add_const(vec_bottom, vec_bottom, (T)offset, batches_stride); + // b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset) + // ---> vec_top; + __bang_mul(vec_top, vec_top, vec_bottom, batches_stride); + T *b1_area = vec_top; + + // get the b2_area + // (b2_x2 - b2_x1 + offset) ---> b2_x1 + __bang_sub(vec_b2_x1, vec_b2_x2, vec_b2_x1, batches_stride); + __bang_add_const(vec_b2_x1, vec_b2_x1, (T)offset, batches_stride); + // (b2_y2 - b2_y1 + offset) ---> b2_y1 + __bang_sub(vec_b2_y1, vec_b2_y2, vec_b2_y1, batches_stride); + __bang_add_const(vec_b2_y1, vec_b2_y1, (T)offset, batches_stride); + // b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset) + // ---> b2_x1; + __bang_mul(vec_b2_x1, vec_b2_x1, vec_b2_y1, batches_stride); + T *b2_area = vec_b2_x1; + + // inter_s = width * height + __bang_mul(height, width, height, batches_stride); + T *inter_s = height; + + // offset vector ---> vec_b2_y1 + __nramset(vec_b2_y1, batches_stride, T(offset)); + T *vec_offset = vec_b2_y1; + + if (mode == 0) { + __bang_add(b1_area, b1_area, b2_area, batches_stride); + __bang_sub(b1_area, b1_area, inter_s, batches_stride); + __bang_maxequal(b1_area, vec_offset, b1_area, batches_stride); + } else { + __bang_maxequal(b1_area, vec_offset, b1_area, batches_stride); + } + T *base_s = b1_area; + + // ious = inter_s / base_s + computeDiv(width, inter_s, base_s, vec_b2_x2, batches_stride); + int32_t gdram_offset = index1 * num_bbox2 + index2; + __memcpy((T *)ious + gdram_offset, width, handle_batches * sizeof(T), + NRAM2GDRAM); + } + } + } +} + +template +__mlu_global__ void MLUUnion1KernelBBoxOverlaps( + const void *bbox1, const void *bbox2, void *ious, const int32_t num_bbox1, + const int32_t num_bbox2, const int32_t mode, const bool aligned, + const int32_t offset) { + /* + * NRAM partition + * |-------------------------------------------------------------| + * | vec_b1_x1 | vec_b1_y1 | vec_b1_x2 | vec_b1_y2 | + * |-------------------------------------------------------------| + * | vec_b2_x1 | vec_b2_y1 | vec_b2_x2 | vec_b2_y2 | + * |-------------------------------------------------------------| + * | vec_left | vec_right | vec_top | vec_bottom | + * |-------------------------------------------------------------| + * + */ + const int32_t align_bytes = PAD_DOWN(MAX_NRAM_SIZE, NFU_ALIGN_SIZE); + const int32_t split_nram_num = 12; + const int32_t nram_stride = + align_bytes / NFU_ALIGN_SIZE / split_nram_num * NFU_ALIGN_SIZE; + + void *vec_b1_x1 = nmem_buf; + void *vec_b1_y1 = nmem_buf + nram_stride; + void *vec_b1_x2 = nmem_buf + 2 * nram_stride; + void *vec_b1_y2 = nmem_buf + 3 * nram_stride; + + void *vec_b2_x1 = nmem_buf + 4 * nram_stride; + void *vec_b2_y1 = nmem_buf + 5 * nram_stride; + void *vec_b2_x2 = nmem_buf + 6 * nram_stride; + void *vec_b2_y2 = nmem_buf + 7 * nram_stride; + + void *vec_left = nmem_buf + 8 * nram_stride; + void *vec_right = nmem_buf + 9 * nram_stride; + void *vec_top = nmem_buf + 10 * nram_stride; + void *vec_bottom = nmem_buf + 11 * nram_stride; + + const int32_t vec_length = nram_stride / sizeof(T); + bboxOverlapsWorkflow((T *)vec_b1_x1, (T *)vec_b1_y1, (T *)vec_b1_x2, + (T *)vec_b1_y2, (T *)vec_b2_x1, (T *)vec_b2_y1, + (T *)vec_b2_x2, (T *)vec_b2_y2, (T *)vec_left, + (T *)vec_right, (T *)vec_top, (T *)vec_bottom, + (T *)bbox1, (T *)bbox2, (T *)ious, offset, mode, + vec_length, num_bbox1, num_bbox2, aligned); +} + +void KernelBBoxOverlaps(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const void *bbox1, const void *bbox2, void *ious, + const int32_t num_bbox1, const int32_t num_bbox2, + const int32_t mode, const bool aligned, + const int32_t offset) { + if (d_type == CNRT_FLOAT16) { + MLUUnion1KernelBBoxOverlaps<<>>( + bbox1, bbox2, ious, num_bbox1, num_bbox2, mode, aligned, offset); + } else { + MLUUnion1KernelBBoxOverlaps<<>>( + bbox1, bbox2, ious, num_bbox1, num_bbox2, mode, aligned, offset); + } +} diff --git a/mmcv/ops/csrc/pytorch/bbox_overlaps.cpp b/mmcv/ops/csrc/pytorch/bbox_overlaps.cpp index 073110dfc8..f4664c9e82 100644 --- a/mmcv/ops/csrc/pytorch/bbox_overlaps.cpp +++ b/mmcv/ops/csrc/pytorch/bbox_overlaps.cpp @@ -12,6 +12,17 @@ void bbox_overlaps_cuda(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, } #endif +#ifdef MMCV_WITH_MLU +void BBoxOverlapsMLUKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, + Tensor ious, const int mode, + const bool aligned, const int offset); + +void bbox_overlaps_mlu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, + const int mode, const bool aligned, const int offset) { + BBoxOverlapsMLUKernelLauncher(bboxes1, bboxes2, ious, mode, aligned, offset); +} +#endif + void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset) { if (bboxes1.device().is_cuda()) { @@ -23,6 +34,14 @@ void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, bbox_overlaps_cuda(bboxes1, bboxes2, ious, mode, aligned, offset); #else AT_ERROR("bbox_overlaps is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (bboxes1.device().type() == at::kMLU) { + CHECK_MLU_INPUT(bboxes1); + CHECK_MLU_INPUT(bboxes2); + CHECK_MLU_INPUT(ious); + + bbox_overlaps_mlu(bboxes1, bboxes2, ious, mode, aligned, offset); #endif } else { AT_ERROR("bbox_overlaps is not implemented on CPU"); diff --git a/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp new file mode 100644 index 0000000000..3d4b022bb5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp @@ -0,0 +1,90 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "pytorch_mlu_helper.hpp" + +void KernelBBoxOverlaps(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const void *bbox1, const void *bbox2, void *ious, + const int32_t num_bbox1, const int32_t num_bbox2, + const int32_t mode, const bool aligned, + const int32_t offset); + +static void policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, + const int32_t batch_num_all) { + auto union_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + auto core_num = union_num * core_dim; + + // Union1 policyFunc + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = core_dim; + auto need_core_num = PAD_UP(batch_num_all, core_dim); + k_dim->y = + (need_core_num < core_num) ? (need_core_num / core_dim) : union_num; + k_dim->z = 1; + + return; +} + +void BBoxOverlapsMLUKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, + Tensor ious, const int32_t mode, + const bool aligned, const int32_t offset) { + // check dtype + TORCH_CHECK( + bboxes1.scalar_type() == at::kFloat || bboxes1.scalar_type() == at::kHalf, + "Data type of input should be Float or Half. But now input type is ", + bboxes1.scalar_type(), "."); + TORCH_CHECK(bboxes1.scalar_type() == bboxes2.scalar_type(), + "bboxes1's dtype should be the same with bboxes2's dtype."); + + // params check + TORCH_CHECK(bboxes1.dim() == 2, "bboxes1 should be a 2d tensor, got ", + bboxes1.dim(), "D"); + TORCH_CHECK(bboxes2.dim() == 2, "bboxes2 should be a 2d tensor, got ", + bboxes2.dim(), "D"); + + auto rows = bboxes1.size(0); + auto cols = bboxes2.size(0); + auto batch_num_all = rows; + + if (rows * cols == 0) { + // return if zero element + return; + } + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFunc(&k_dim, &k_type, batch_num_all); + + // get compute queue + cnrtQueue_t queue = torch_mlu::getCurQueue(); + + // get dtype of input + cnrtDataType_t d_type = torch_mlu::toCnrtDtype(bboxes1.dtype()); + + // get ptr of tensors + auto bboxes1_impl = torch_mlu::getMluTensorImpl(bboxes1); + auto bboxes1_ptr = bboxes1_impl->cnnlMalloc(); + auto bboxes2_impl = torch_mlu::getMluTensorImpl(bboxes2); + auto bboxes2_ptr = bboxes2_impl->cnnlMalloc(); + auto ious_impl = torch_mlu::getMluTensorImpl(ious); + auto ious_ptr = ious_impl->cnnlMalloc(); + + // launch kernel + CNLOG(INFO) << "Launch Kernel MLUUnion1BboxOverlapsKernel"; + CNLOG(INFO) << "kDim :[ " << k_dim.x << ", " << k_dim.y << ", " << k_dim.z + << " ]"; + KernelBBoxOverlaps(k_dim, k_type, queue, d_type, bboxes1_ptr, bboxes2_ptr, + ious_ptr, rows, cols, mode, aligned, offset); +} diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index cff7bcca6c..06bafeb1b4 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -2,41 +2,59 @@ import pytest import torch +from mmcv.utils import is_cuda, is_mlu -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -class TestBBox(object): - def _test_bbox_overlaps(self, dtype=torch.float): +class TestBBox(object): + def _test_bbox_overlaps(self, device, dtype=torch.float): from mmcv.ops import bbox_overlaps b1 = torch.tensor([[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, 4.0], - [7.0, 7.0, 8.0, 8.0]]).cuda().type(dtype) + [7.0, 7.0, 8.0, 8.0]]).to(device).type(dtype) b2 = torch.tensor([[0.0, 2.0, 2.0, 5.0], [2.0, 1.0, 3.0, - 3.0]]).cuda().type(dtype) + 3.0]]).to(device).type(dtype) should_output = np.array([[0.33333334, 0.5], [0.2, 0.5], [0.0, 0.0]]) out = bbox_overlaps(b1, b2, offset=1) assert np.allclose(out.cpu().numpy(), should_output, 1e-2) b1 = torch.tensor([[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, - 4.0]]).cuda().type(dtype) + 4.0]]).to(device).type(dtype) b2 = torch.tensor([[0.0, 2.0, 2.0, 5.0], [2.0, 1.0, 3.0, - 3.0]]).cuda().type(dtype) + 3.0]]).to(device).type(dtype) should_output = np.array([0.33333334, 0.5]) out = bbox_overlaps(b1, b2, aligned=True, offset=1) assert np.allclose(out.cpu().numpy(), should_output, 1e-2) - b1 = torch.tensor([[0.0, 0.0, 3.0, 3.0]]).cuda().type(dtype) - b1 = torch.tensor([[0.0, 0.0, 3.0, 3.0]]).cuda().type(dtype) + b1 = torch.tensor([[0.0, 0.0, 3.0, 3.0]]).to(device).type(dtype) b2 = torch.tensor([[4.0, 0.0, 5.0, 3.0], [3.0, 0.0, 4.0, 3.0], [2.0, 0.0, 3.0, 3.0], [1.0, 0.0, 2.0, - 3.0]]).cuda().type(dtype) + 3.0]]).to(device).type(dtype) should_output = np.array([0, 0.2, 0.5, 0.5]) out = bbox_overlaps(b1, b2, offset=1) assert np.allclose(out.cpu().numpy(), should_output, 1e-2) - def test_bbox_overlaps_float(self): - self._test_bbox_overlaps(torch.float) - - def test_bbox_overlaps_half(self): - self._test_bbox_overlaps(torch.half) + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not is_mlu(), reason='requires MLU support')) + ]) + def test_bbox_overlaps_float(self, device): + self._test_bbox_overlaps(device, dtype=torch.float) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not is_mlu(), reason='requires MLU support')) + ]) + def test_bbox_overlaps_half(self, device): + self._test_bbox_overlaps(device, dtype=torch.half) From 430cfa6863403238eddb5087c42ad30403537eaf Mon Sep 17 00:00:00 2001 From: Wangjiazhen <841713301@qq.com> Date: Tue, 7 Dec 2021 16:49:02 +0800 Subject: [PATCH 18/30] [Refactor] Format C++ code --- mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp | 2 +- mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp | 9 +++++---- mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp | 2 +- mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp index f97ba24db4..82d55559c5 100644 --- a/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp @@ -10,8 +10,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_mlu_helper.hpp" #include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" void KernelBBoxOverlaps(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, diff --git a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp index b35c68bd47..5bd545367f 100644 --- a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp @@ -12,8 +12,8 @@ #include #include -#include "pytorch_mlu_helper.hpp" #include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, @@ -328,7 +328,6 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight, gamma, alpha); } - void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); @@ -336,5 +335,7 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha); -REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, MLU, sigmoid_focal_loss_forward_mlu); -REGISTER_DEVICE_IMPL(sigmoid_focal_loss_backward_impl, MLU, sigmoid_focal_loss_backward_mlu); +REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, MLU, + sigmoid_focal_loss_forward_mlu); +REGISTER_DEVICE_IMPL(sigmoid_focal_loss_backward_impl, MLU, + sigmoid_focal_loss_backward_mlu); diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp index 5dc933125f..18df00dbce 100644 --- a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp @@ -10,8 +10,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_mlu_helper.hpp" #include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t data_type_input, const void *boxes_ptr, diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp index 5cb7024b00..077dbfc51e 100644 --- a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp @@ -9,8 +9,8 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_mlu_helper.hpp" #include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, From 105cccb619f5286d88ef8c1c03b981dd09569b28 Mon Sep 17 00:00:00 2001 From: Wangjiazhen <841713301@qq.com> Date: Mon, 13 Dec 2021 18:25:11 +0800 Subject: [PATCH 19/30] [Refactor] include common_mlu_helper in pytorch_mlu_helper and refactor build condition --- mmcv/ops/csrc/common/pytorch_mlu_helper.hpp | 1 + setup.py | 47 +++++++++++---------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp index cd6fc568bb..4e16cabc45 100644 --- a/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp @@ -14,6 +14,7 @@ #ifdef MMCV_WITH_MLU #include "aten.h" +#include "common_mlu_helper.hpp" #define NFU_ALIGN_SIZE 128 diff --git a/setup.py b/setup.py index 0a879be2a6..f41531e13f 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,9 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True): CommandLine: python -c "import setup; print(setup.parse_requirements())" """ - import sys from os.path import exists + + import sys require_fpath = fname def parse_line(line): @@ -185,6 +186,7 @@ def get_extensions(): if EXT_TYPE == 'parrots': ext_name = 'mmcv._ext' from parrots.utils.build_extension import Extension + # new parrots op impl do not use MMCV_USE_PARROTS # define_macros = [('MMCV_USE_PARROTS', None)] define_macros = [] @@ -287,28 +289,27 @@ def get_extensions(): extension = CUDAExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda')) + elif hasattr( + torch, + 'is_mlu_available') and torch.is_mlu_available() or os.getenv( + 'FORCE_MLU', '0') == '1': + from torch_mlu.utils.cpp_extension import MLUExtension + define_macros += [('MMCV_WITH_MLU', None)] + mlu_args = os.getenv('MMCV_MLU_ARGS') + extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + extension = MLUExtension + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu')) else: - try: - if torch.is_mlu_available(): - from torch_mlu.utils.cpp_extension import MLUExtension - define_macros += [('MMCV_WITH_MLU', None)] - mlu_args = os.getenv('MMCV_MLU_ARGS') - extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] - op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') - op_files += glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') - op_files += glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') - op_files += glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') - extension = MLUExtension - include_dirs.append( - os.path.abspath('./mmcv/ops/csrc/common')) - else: - print('Cambricon Catch is not available!') - except AttributeError: - print(f'Compiling {ext_name} without CUDA') - op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') - extension = CppExtension - include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + print(f'Compiling {ext_name} only with CPU') + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + extension = CppExtension + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) # Since the PR (https://github.com/open-mmlab/mmcv/pull/1463) uses # c++14 features, the argument ['std=c++14'] must be added here. @@ -330,8 +331,8 @@ def get_extensions(): if EXT_TYPE == 'pytorch' and os.getenv('MMCV_WITH_ORT', '0') != '0': ext_name = 'mmcv._ext_ort' - from torch.utils.cpp_extension import library_paths, include_paths import onnxruntime + from torch.utils.cpp_extension import include_paths, library_paths library_dirs = [] libraries = [] include_dirs = [] From 9d007f951f1ba7b267a6cae026c86c1df577f28a Mon Sep 17 00:00:00 2001 From: zhouchenyang Date: Tue, 21 Dec 2021 21:29:10 +0800 Subject: [PATCH 20/30] [Improve] Improve the performance of roialign, nms and focalloss with MLU backend (#1572) * [Improve] Improve the performance of roialign with MLU backend * replace CHECK_MLU with CHECK_MLU_INPUT * [Improve] Improve the perf of nms and focallosssigmoid with MLU backend --- .../mlu/focal_loss_sigmoid_mlu_kernel.mlu | 462 +++--- mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu | 526 ++++++- .../csrc/common/mlu/roi_align_mlu_kernel.mlu | 1247 ++++++++++++----- mmcv/ops/csrc/pytorch/focal_loss.cpp | 14 +- .../pytorch/mlu/focal_loss_sigmoid_mlu.cpp | 69 +- mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp | 72 +- mmcv/ops/csrc/pytorch/roi_align.cpp | 20 +- 7 files changed, 1816 insertions(+), 594 deletions(-) diff --git a/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu index 028f6c0c9d..7624379b68 100644 --- a/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu @@ -24,8 +24,21 @@ __mlu_func__ void loadInput(char *nram_input, T *dram_input, const int32_t size, const int32_t dst_stride = 0, const int32_t src_stride = 0, const int32_t count = 1) { - __memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride, - src_stride, count - 1); + if (dst_stride == src_stride) { + __memcpy_async(nram_input, dram_input, size * count, GDRAM2NRAM); + } else { + __memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride, + src_stride, count - 1); + } +} + +template +__mlu_func__ void loadWeight(char *nram_input, T *dram_input, const int32_t t, + const int32_t c, const int32_t has_weight, + const int32_t partition_nc) { + if (has_weight && partition_nc && t >= 0 && t < c) { + __memcpy_async(nram_input, (T *)dram_input + t, sizeof(T), GDRAM2NRAM); + } } template @@ -33,152 +46,117 @@ __mlu_func__ void storeOutput(T *dram_output, char *nram_output, const int32_t size, const int32_t dst_stride = 0, const int32_t src_stride = 0, const int32_t count = 1) { - __memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride, - src_stride, count - 1); + if (dst_stride == src_stride) { + __memcpy_async(dram_output, nram_output, size * count, NRAM2GDRAM); + } else { + __memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride, + src_stride, count - 1); + } } template __mlu_func__ void compute(T *input, const int32_t *target, const T *weight, - const int32_t has_weight, const int32_t deal_num, - const int32_t n_seg, const int32_t C, float alpha, - float gamma, T *scalar_temp, T *tensor_max, - T *tensor_temp, T *output) { - const int32_t scalar_elem_num = NFU_ALIGN_SIZE / sizeof(T); - - // 0. n_max = max(0, x) - __nramset((T *)tensor_max, deal_num, (T)0); - __bang_cycle_maxequal((T *)tensor_max, (T *)tensor_max, (T *)input, deal_num, - deal_num); - - // 1. ln(1+e^x) = ln(e^(-max) + e^(x-max)) + max - __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); - __bang_cycle_mul((T *)tensor_temp, (T *)tensor_max, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_cycle_add((T *)output, (T *)input, (T *)tensor_temp, deal_num, - deal_num); - __bang_active_exphp((T *)output, (T *)output, deal_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - __bang_cycle_add((T *)output, (T *)output, (T *)tensor_temp, deal_num, - deal_num); - __bang_active_loghp((T *)output, (T *)output, deal_num); - __bang_cycle_add((T *)output, (T *)output, (T *)tensor_max, deal_num, - deal_num); - - // 2. temp = [1 + e^(-x)] ^ (-r) - __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); - __bang_cycle_mul((T *)tensor_temp, (T *)input, (T *)scalar_temp, deal_num, - scalar_elem_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - __nramset((T *)scalar_temp, scalar_elem_num, (T)1); - __bang_cycle_add((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_active_loghp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - __nramset((T *)scalar_temp, scalar_elem_num, (T)(-gamma)); - __bang_cycle_mul((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - // 3.1 output: target != j - __nramset((T *)scalar_temp, scalar_elem_num, (T)(1 - alpha)); - __bang_cycle_mul((T *)output, (T *)output, (T *)scalar_temp, deal_num, - scalar_elem_num); - __bang_cycle_mul((T *)output, (T *)output, (T *)tensor_temp, deal_num, - deal_num); - - // 3.2 output: target == j - const int32_t c_align_size = PAD_UP((sizeof(T) * C), NFU_ALIGN_SIZE); + const int32_t has_weight, const int32_t partition_nc, + const int32_t deal_num, const int32_t n_seg, + const int32_t c, const int32_t c_seg, + const int32_t c_start_index, const float alpha, + const float gamma, T *compute_a, T *compute_b, + T *output) { + // set params + const int32_t c_num = + has_weight ? PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)) : c_seg; + const int32_t c_end_index = c_start_index + c_seg; + const int32_t half_epsilon = 0x0400; + const T epsilon_f = + sizeof(T) == sizeof(float) ? FLT_MIN : *((half *)&half_epsilon); + + // 0. alpha_t * p_t^r = alpha * (1 - p) ^ gamma if t == c_i + // = (1 - alpha) * p ^ gamma if t != c_i + __nramset((T *)output, deal_num, (T)(1 - alpha)); + __bang_active_sigmoid((T *)compute_b, (T *)input, deal_num); for (int32_t i = 0; i < n_seg; ++i) { - const int32_t target_value = *((int32_t *)target + i); - if (target_value >= 0 && target_value < C) { - const int32_t offset = i * c_align_size + target_value * sizeof(T); - char *addr_input = (char *)input + offset; - char *addr_output = (char *)output + offset; - const float x = *(T *)addr_input; - const float p = 1. / (1. + exp(-x)); - *(T *)addr_output = -alpha * pow(1. - p, gamma) * log(fmax(p, FLT_MIN)); + const int32_t t = *((uint32_t *)target + i); + if (t >= c_start_index && t < c_end_index) { + const uint32_t index = i * c_num + t - c_start_index; + *((T *)input + index) = -1.0 * (*((T *)input + index)); + *((T *)compute_b + index) = 1.0 - (*((T *)compute_b + index)) + epsilon_f; + *((T *)output + index) = alpha; } } + if (sizeof(T) == sizeof(half)) { + __bang_half2float((float *)compute_a, (half *)compute_b, deal_num); + __bang_active_loghp((float *)compute_a, (float *)compute_a, deal_num); + __bang_mul_const((float *)compute_a, (float *)compute_a, (float)gamma, + deal_num); + __bang_active_exphp((float *)compute_a, (float *)compute_a, deal_num); + __bang_float2half_rd((half *)compute_a, (float *)compute_a, deal_num); + } else { + __bang_active_loghp((T *)compute_a, (T *)compute_b, deal_num); + __bang_mul_const((T *)compute_a, (T *)compute_a, (T)gamma, deal_num); + __bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num); + } + __bang_mul((T *)output, (T *)compute_a, (T *)output, deal_num); + + // 1. max = max(0, -x) if t == c_i + // = max(0, x) if t != c_i + __nramset((T *)compute_b, deal_num, (T)0); + __bang_maxequal((T *)compute_b, (T *)compute_b, (T *)input, deal_num); - // with weight - if (has_weight > 0) { - int32_t row_num_elem = deal_num / n_seg; + // 2. -log(p_t) = ln(e^(-max)+ e^(-max-x) + max if t == c_i + // = ln(e^(-max)+ e^(-max+x) + max if t != c_i + __bang_mul_const((T *)compute_a, (T *)compute_b, (T)-1.0, deal_num); + __bang_add((T *)input, (T *)compute_a, (T *)input, deal_num); + + __bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num); + __bang_active_exphp((T *)input, (T *)input, deal_num); + __bang_add((T *)compute_a, (T *)compute_a, (T *)input, deal_num); + __bang_active_loghp((T *)compute_a, (T *)compute_a, deal_num); + __bang_add((T *)input, (T *)compute_a, (T *)compute_b, deal_num); + + // 3. output = alpha_t * p_t^r * [-log(p_t)] + __bang_mul((T *)output, (T *)output, (T *)input, deal_num); + + // 4. with weight + if (has_weight) { for (int32_t i = 0; i < n_seg; ++i) { - const int32_t t = *((int32_t *)target + i); - __nramset((T *)scalar_temp, scalar_elem_num, *((T *)weight + t)); - __bang_cycle_mul((T *)output + i * row_num_elem, - (T *)output + i * row_num_elem, (T *)scalar_temp, - row_num_elem, scalar_elem_num); + int32_t t = *((int32_t *)target + i); + if (t >= 0 && t < c) { + t = partition_nc ? 0 : t; + __bang_mul_const((T *)output + i * c_num, (T *)output + i * c_num, + *((T *)weight + t), c_num); + } } } } template -__mlu_func__ void focalLossSigmoidForwardBlock( +__mlu_func__ void startPipeline( const T *input, const int32_t *target, const T *weight, - const int32_t row_num, const int32_t C, const float alpha, - const float gamma, T *output) { - /* - * NRAM partition - * |-----------------------------------------------------------------------| - * | scalar | - * |-----------------------------------------------------------------------| - * | weight | - * |------------------------------- COMPUTE -------------------------------| - * | | | - * | computeA | computeB | - * | | | - * |------------- PING ------------------------------- PONG ---------------| - * | | | - * | input | input | - * | | | - * |-----------------------------------|-----------------------------------| - * | | | - * | output | output | - * | | | - * |-----------------------------------|-----------------------------------| - * | target | target | - * |-----------------------------------|-----------------------------------| - * - * split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output), - * PONG(input,output). - * split_target_num is 2: PING(target), PONG(target). - */ - const int32_t c_align = PAD_UP(C, NFU_ALIGN_SIZE / sizeof(T)); - const int32_t c_align_size = c_align * sizeof(T); - const int32_t scalar_size = NFU_ALIGN_SIZE; - const int32_t weight_size = (weight != NULL) * c_align_size; - const int32_t split_pipeline_num = 6; - const int32_t split_target_num = 2; - - const int32_t remain_size = MAX_NRAM_SIZE - scalar_size - weight_size; - const int32_t n_seg = remain_size / (split_pipeline_num * c_align_size + - split_target_num * sizeof(int32_t)); - const int32_t deal_num = n_seg * c_align_size / sizeof(T); - const int32_t target_size = n_seg * sizeof(int32_t); - - // nram scalar,weight - char *nram_scalar = (char *)nram_buffer; - char *nram_weight = (char *)nram_scalar + scalar_size; - if (weight_size > 0) { - loadInput(nram_weight, (T *)weight, C * sizeof(T)); - __asm__ volatile("sync;"); - } - - // nram COMPUTE - const int32_t compute_size = 2 * c_align_size * n_seg; - char *nram_compute_a = (char *)nram_weight + weight_size; - char *nram_compute_b = (char *)nram_compute_a + c_align_size * n_seg; - - // nram PING/PONG - const int32_t pingpong_offset = (remain_size - compute_size) / 2; - char *nram_input = (char *)nram_compute_a + 2 * c_align_size * n_seg; - char *nram_output = (char *)nram_compute_a + 3 * c_align_size * n_seg; - char *nram_target = (char *)nram_compute_a + 4 * c_align_size * n_seg; - - const int32_t repeat = row_num / n_seg; - const int32_t remain = row_num % n_seg; + char *nram_compute_a, char *nram_compute_b, char *nram_input, + char *nram_target, char *nram_weight, char *nram_output, + const int32_t has_weight, const int32_t partition_nc, + const int32_t pingpong_offset, const int32_t pingpong_weight_offset, + const int32_t c_offset_num, const int32_t n, const int32_t n_seg, + const int32_t c, const int32_t c_seg, const float alpha, const float gamma, + T *output) { + // with offset + input = (T *)((char *)input + c_offset_num * sizeof(T)); + output = (T *)((char *)output + c_offset_num * sizeof(T)); + + const int32_t c_seg_align_num = PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t c_num = has_weight ? c_seg_align_num : c_seg; + const int32_t deal_num = PAD_UP(n_seg * c_num, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t load_size = c_seg * sizeof(T); + const int32_t dram_stride = c * sizeof(T); + const int32_t nram_stride = c_num * sizeof(T); + + if (has_weight && !partition_nc) { + loadInput(nram_weight, (T *)weight, load_size, nram_stride, dram_stride, + 1); + __asm__ volatile("sync;\n\t"); + } + const int32_t repeat = n / n_seg; + const int32_t remain = n % n_seg; /* * Pipeline: The pipeline is processed in three stages: Load, Compute, Store. @@ -206,80 +184,214 @@ __mlu_func__ void focalLossSigmoidForwardBlock( // diagram of PINGPONG: L0 if (repeat > 0) { - loadInput(nram_input, (T *)input, C * sizeof(T), c_align * sizeof(T), - C * sizeof(T), n_seg); - loadInput(nram_target, (int32_t *)target, target_size); - __asm__ volatile("sync;"); + loadInput(nram_input, (T *)input, load_size, nram_stride, dram_stride, + n_seg); + loadInput(nram_target, (int32_t *)target, n_seg * sizeof(int32_t)); + loadWeight(nram_weight, (T *)weight, *((int32_t *)target), c, has_weight, + partition_nc); + __asm__ volatile("sync;\n\t"); } // diagram of PINGPONG: C0 and L1 if (repeat > 1) { - loadInput(nram_input + pingpong_offset, (T *)input + C * n_seg, - C * sizeof(T), c_align * sizeof(T), C * sizeof(T), n_seg); - loadInput(nram_target + pingpong_offset, (int32_t *)target + n_seg, - target_size); compute((T *)nram_input, (int32_t *)nram_target, (T *)nram_weight, - weight_size, deal_num, n_seg, C, alpha, gamma, (T *)nram_scalar, - (T *)nram_compute_a, (T *)nram_compute_b, (T *)nram_output); - __asm__ volatile("sync;"); + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)nram_output); + loadInput((char *)nram_input + pingpong_offset, (T *)input + c * n_seg, + load_size, nram_stride, dram_stride, n_seg); + loadInput((char *)nram_target + pingpong_offset, + (int32_t *)target + n_seg, n_seg * sizeof(int32_t)); + loadWeight((char *)nram_weight + pingpong_weight_offset, (T *)weight, + *((int32_t *)target + n_seg), c, has_weight, partition_nc); + __asm__ volatile("sync;\n\t"); } for (int32_t i = 0; i < repeat - 2; ++i) { - storeOutput((T *)output + i * C * n_seg, - nram_output + (i % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), n_seg); - loadInput(nram_input + (i % 2) * pingpong_offset, - (T *)input + (i + 2) * C * n_seg, C * sizeof(T), - c_align * sizeof(T), C * sizeof(T), n_seg); - loadInput(nram_target + (i % 2) * pingpong_offset, - (int32_t *)target + (i + 2) * n_seg, target_size); + storeOutput((T *)output + i * c * n_seg, + nram_output + (i % 2) * pingpong_offset, load_size, + dram_stride, nram_stride, n_seg); + loadInput((char *)nram_input + (i % 2) * pingpong_offset, + (T *)(input) + (i + 2) * c * n_seg, load_size, nram_stride, + dram_stride, n_seg); + loadInput((char *)nram_target + (i % 2) * pingpong_offset, + (int32_t *)target + (i + 2) * n_seg, + n_seg * sizeof(int32_t)); + loadWeight((char *)nram_weight + (i % 2) * pingpong_weight_offset, + (T *)weight, *((int32_t *)target + (i + 2) * n_seg), c, + has_weight, partition_nc); compute((T *)(nram_input + ((i + 1) % 2) * pingpong_offset), (int32_t *)(nram_target + ((i + 1) % 2) * pingpong_offset), - (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, - (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * ((i + 1) % 2) * pingpong_weight_offset), + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + ((i + 1) % 2) * pingpong_offset)); - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); } if (repeat > 1) { - storeOutput((T *)output + (repeat - 2) * C * n_seg, - nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), n_seg); + storeOutput((T *)output + (repeat - 2) * c * n_seg, + (char *)nram_output + (repeat % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, n_seg); } + if (remain > 0) { - loadInput(nram_input + (repeat % 2) * pingpong_offset, - (T *)input + repeat * C * n_seg, C * sizeof(T), - c_align * sizeof(T), C * sizeof(T), remain); - loadInput(nram_target + (repeat % 2) * pingpong_offset, + loadInput((char *)nram_input + (repeat % 2) * pingpong_offset, + (T *)input + repeat * c * n_seg, load_size, nram_stride, + dram_stride, remain); + loadInput((char *)nram_target + (repeat % 2) * pingpong_offset, (int32_t *)target + repeat * n_seg, remain * sizeof(int32_t)); + loadWeight((char *)nram_weight + (repeat % 2) * pingpong_weight_offset, + (T *)weight, *((int32_t *)target + repeat * n_seg), c, + has_weight, partition_nc); } + if (repeat > 0) { compute((T *)(nram_input + ((repeat - 1) % 2) * pingpong_offset), (int32_t *)(nram_target + ((repeat - 1) % 2) * pingpong_offset), - (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, - (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * ((repeat - 1) % 2) * pingpong_weight_offset), + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + ((repeat - 1) % 2) * pingpong_offset)); } - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); if (repeat > 0) { - storeOutput((T *)output + (repeat - 1) * C * n_seg, - nram_output + ((repeat - 1) % 2) * pingpong_offset, - C * sizeof(T), C * sizeof(T), c_align * sizeof(T), n_seg); + storeOutput((T *)output + (repeat - 1) * c * n_seg, + (char *)nram_output + ((repeat - 1) % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, n_seg); } + if (remain > 0) { - int rem_deal_num = remain * c_align_size / sizeof(T); + int32_t rem_num = PAD_UP(remain * c_num, NFU_ALIGN_SIZE / sizeof(T)); compute((T *)(nram_input + (repeat % 2) * pingpong_offset), (int32_t *)(nram_target + (repeat % 2) * pingpong_offset), - (T *)nram_weight, weight_size, rem_deal_num, remain, C, alpha, - gamma, (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * (repeat % 2) * pingpong_weight_offset), + has_weight, partition_nc, rem_num, remain, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + (repeat % 2) * pingpong_offset)); - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); + + storeOutput((T *)output + repeat * c * n_seg, + (char *)nram_output + (repeat % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, remain); + } + __asm__ volatile("sync;\n\t"); +} - storeOutput((T *)output + repeat * C * n_seg, - nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), remain); +template +__mlu_func__ void focalLossSigmoidForwardBlock( + const T *input, const int32_t *target, const T *weight, const int32_t n, + const int32_t c, const float alpha, const float gamma, T *output) { + /* + * NRAM partition + * |-----------------------------------------------------------------------| + * | weight | + * |------------------------------- COMPUTE -------------------------------| + * | | | + * | computeA | computeB | + * | | | + * |------------- PING ------------------------------- PONG ---------------| + * | | | + * | input | input | + * | | | + * |-----------------------------------|-----------------------------------| + * | | | + * | output | output | + * | | | + * |-----------------------------------|-----------------------------------| + * | target | target | + * |-----------------------------------|-----------------------------------| + * + * split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output), + * PONG(input,output). + * split_target_num is 2: PING(target), PONG(target). + * weight is not NULL: + * The nram-size of weight is equal to c_align_size when partition input-N. + * The nram-size of weight is equal to NFU_ALIGN_SIZE when partition + * input-NC. + */ + + // calculate threshold of c + const int32_t split_pipeline_num = 6; + const int32_t split_target_num = 2; + const int32_t has_weight = weight != NULL; + const int32_t threshold_c = + PAD_DOWN((MAX_NRAM_SIZE - split_target_num * sizeof(int32_t)) / + (split_pipeline_num + has_weight), + NFU_ALIGN_SIZE) / + sizeof(T); + const int32_t c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t c_align_size = c_align * sizeof(T); + + if (c <= threshold_c) { + // partition inputN + int32_t c_num = c; + int32_t reservered_align_size = + (split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE; + int32_t weight_size = 0; + if (has_weight) { + c_num = c_align; + reservered_align_size = split_target_num * NFU_ALIGN_SIZE; + weight_size = c_align_size; + } + + const int32_t remain_size = + MAX_NRAM_SIZE - weight_size - reservered_align_size; + const int32_t n_seg = + remain_size / (split_pipeline_num * c_num * sizeof(T) + + split_target_num * sizeof(int32_t)); + const int32_t split_pipeline_size = + PAD_UP(c_num * n_seg * sizeof(T), NFU_ALIGN_SIZE); + const int32_t compute_size = 2 * split_pipeline_size; + const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2; + + char *nram_weight = (char *)nram_buffer; + char *nram_compute_a = nram_weight + has_weight * c_align_size; + char *nram_compute_b = nram_compute_a + split_pipeline_size; + char *nram_input = nram_compute_b + split_pipeline_size; + char *nram_output = nram_input + split_pipeline_size; + char *nram_target = nram_output + split_pipeline_size; + + startPipeline(input, target, weight, nram_compute_a, nram_compute_b, + nram_input, nram_target, nram_weight, nram_output, + has_weight, 0, pingpong_offset, 0, 0, n, n_seg, c, c, + alpha, gamma, output); + } else { + // partition inputNC + const int32_t weight_size = has_weight * NFU_ALIGN_SIZE; + const int32_t remain_size = MAX_NRAM_SIZE - weight_size; + const int32_t split_pipeline_size = PAD_DOWN( + (remain_size - split_target_num * NFU_ALIGN_SIZE) / split_pipeline_num, + NFU_ALIGN_SIZE); + const int32_t c_seg = split_pipeline_size / sizeof(T); + const int32_t n_seg = 1; + const int32_t compute_size = 2 * split_pipeline_size; + const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2; + const int32_t pingpong_weight_offset = weight_size / 2; + + char *nram_weight = (char *)nram_buffer; + char *nram_compute_a = nram_weight + weight_size; + char *nram_compute_b = nram_compute_a + split_pipeline_size; + char *nram_input = nram_compute_b + split_pipeline_size; + char *nram_output = nram_input + split_pipeline_size; + char *nram_target = nram_output + split_pipeline_size; + + const int32_t loop_num = (c + c_seg - 1) / c_seg; + const int32_t partition_nc = 1; + for (int32_t i = 0; i < loop_num; ++i) { + const int32_t c_index = i * c_seg; + const int32_t c_seg_curr = i == (loop_num - 1) ? c - c_index : c_seg; + startPipeline(input, target, weight, nram_compute_a, nram_compute_b, + nram_input, nram_target, nram_weight, nram_output, + has_weight, partition_nc, pingpong_offset, + pingpong_weight_offset, c_index, n, n_seg, c, c_seg_curr, + alpha, gamma, output); + } } } diff --git a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu index 1095da870c..7cb16bb100 100644 --- a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (C) 2021 by Cambricon. + * Copyright (C) 2021 Cambricon. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF @@ -15,6 +15,8 @@ #define COORD_DIM (4) #define MEMORY_CORE (0x80) #define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score +#define REDUCE_NUM \ + (7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input) #define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024) #define SIZE_SRAM_BUF (MAX_SRAM_SIZE) @@ -551,7 +553,7 @@ __mlu_func__ void nms_detection( } // for keepNum } -__mlu_global__ void MLUKernelNMS( +__mlu_global__ void MLUUnion1KernelNMS( const void *input_boxes, const void *input_confidence, const int input_num_boxes, const int input_stride, const int max_output_size, const float iou_threshold, @@ -635,15 +637,525 @@ __mlu_global__ void MLUKernelNMS( } } +template +__mlu_func__ void nms_detection_ux( + int32_t *loop_end_flag, uint32_t &output_box_num, OUT_DT *output_dram, + IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram, + const int input_layout, const int input_num_boxes, const int input_stride, + const int max_output_size, const float thresh_iou, const float thresh_score, + const float offset, const int output_mode, const int algo) { + loop_end_flag[0] = 0; + IN_DT *sram = (IN_DT *)sram_buffer; + + // score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2 + int nms_buffer_count1 = 9; + // temp nram buffer to store selected target. + int nram_save_limit_count = 256; + float div_thresh_iou = 1.0 / thresh_iou; + + // input data ptr + IN_DT *input_score_ptr; + const IN_DT *input_x1_ptr; + const IN_DT *input_y1_ptr; + const IN_DT *input_x2_ptr; + const IN_DT *input_y2_ptr; + input_score_ptr = score_data; + input_x1_ptr = boxes_data; + input_y1_ptr = input_x1_ptr + input_stride; + input_x2_ptr = input_y1_ptr + input_stride; + input_y2_ptr = input_x2_ptr + input_stride; + + int limit = 0; // find limit when GDRAM or SRAM + int max_seg_pad = 0; // the max length every repeat + int repeat = 0; + int remain = 0; + int remain_pad = 0; + int nram_save_count = 0; + + if (output_mode == 0) { + limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * sizeof(OUT_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } else { + limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * INFO_NUM * sizeof(OUT_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } + + // data split + int avg_cluster = input_num_boxes / clusterDim; + int rem_cluster = input_num_boxes % clusterDim; + int len_cluster = avg_cluster + (clusterId < rem_cluster ? 1 : 0); + int cluster_offset = avg_cluster * clusterId + + (clusterId <= rem_cluster ? clusterId : rem_cluster); + + int avg_core = len_cluster / coreDim; + int rem_core = len_cluster % coreDim; + int len_core = avg_core + (coreId < rem_core ? 1 : 0); + int core_offset = + avg_core * coreId + (coreId <= rem_core ? coreId : rem_core); + int input_offset = cluster_offset + core_offset; + + max_seg_pad = PAD_DOWN(limit, NMS_SIZE); + + // core 0 of each cluster calculate the max score index + int max_index_avg_core = input_num_boxes / clusterDim; + int max_index_rem_core = input_num_boxes % clusterDim; + int max_index_len_core = + max_index_avg_core + (clusterId < max_index_rem_core ? 1 : 0); + int max_index_input_offset = + max_index_avg_core * clusterId + + (clusterId <= max_index_rem_core ? clusterId : max_index_rem_core); + repeat = max_index_len_core / max_seg_pad; + remain = max_index_len_core % max_seg_pad; + remain_pad = PAD_UP(remain, NMS_SIZE); + + // if datatype is fp16, we should cvt to fp32 when compute iou + int max_seg_iou_compute = + PAD_DOWN(max_seg_pad / (sizeof(float) / sizeof(IN_DT)), NMS_SIZE); + int repeat_iou_compute = len_core / max_seg_iou_compute; + int remain_iou_compute = len_core % max_seg_iou_compute; + int remain_pad_iou_compute = PAD_UP(remain_iou_compute, NMS_SIZE); + + // init the nram ptr + IN_DT *score = (IN_DT *)nram_buffer; + IN_DT *x1 = score + max_seg_pad; + IN_DT *y1 = x1 + max_seg_pad; + IN_DT *x2 = y1 + max_seg_pad; + IN_DT *y2 = x2 + max_seg_pad; + IN_DT *inter_x1 = y2 + max_seg_pad; + IN_DT *inter_y1 = inter_x1 + max_seg_pad; + IN_DT *inter_x2 = inter_y1 + max_seg_pad; + IN_DT *inter_y2 = inter_x2 + max_seg_pad; + IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2 + OUT_DT *nram_save = + (OUT_DT *)((char *)max_box + + NFU_ALIGN_SIZE); // offset two line from max_box + + mluMemcpyDirection_t input_load_dir = SRAM2NRAM; + mluMemcpyDirection_t input_store_dir = NRAM2SRAM; + input_load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM; + input_store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM; + + for (int keep = 0; keep < max_output_size; + keep++) { // loop until the max_score <= 0 + __sync_all(); + + /******FIND MAX START******/ + int max_index = 0; + int global_max_index = 0; // for Ux + float max_area = 0; // the max socre area + max_box[0] = 0; // init 0 + + if (coreId == 0) { + for (int i = 0; i <= repeat; i++) { + if (i == repeat && remain == 0) { + break; + } + + int seg_len = (i == repeat) + ? remain_pad + : max_seg_pad; // the length every nms compute + // check seg_len exceeds the limit of fp16 or not. 65536 is the largest + // num + // that fp16 could express. + if (sizeof(IN_DT) == sizeof(half) && seg_len > 65536) { + return; + } + int cpy_len = (i == repeat) + ? remain + : max_seg_pad; // the length every nms memcpy + + /******NMS LOAD START******/ + __bang_write_zero(score, seg_len); + __memcpy(score, + input_score_ptr + max_index_input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + + /******NMS LOAD END******/ + + __bang_max(inter_x1, score, seg_len); + if (inter_x1[0] > max_box[0]) { + max_box[0] = inter_x1[0]; + if (sizeof(IN_DT) == sizeof(half)) { + max_index = + ((uint16_t *)inter_x1)[1] + max_index_input_offset + + i * max_seg_pad; // offset start from head of input_data + } else if (sizeof(IN_DT) == sizeof(float)) { + max_index = + ((uint32_t *)inter_x1)[1] + max_index_input_offset + + i * max_seg_pad; // offset start from head of input_data + } + } + } // for repeat + + // the max box's x1, y1, x2, y2 on every cluster + max_box[1] = input_x1_ptr[max_index]; + max_box[2] = input_y1_ptr[max_index]; + max_box[3] = input_x2_ptr[max_index]; + max_box[4] = input_y2_ptr[max_index]; + ((uint32_t *)(max_box + 5))[0] = max_index; + // copy max box info to sram + __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); + } + __sync_all(); + // copy all partial max to the sram of cluster 0 + if (clusterId != 0) { + __memcpy(sram + REDUCE_NUM * clusterId, sram, REDUCE_NUM * sizeof(IN_DT), + SRAM2SRAM, 0); + } + __sync_all(); + + // reduce between clusters to get the global max box + if (clusterId == 0) { + if (coreId == 0) { + __bang_write_zero(inter_x1, NMS_SIZE); + __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT), + REDUCE_NUM * sizeof(IN_DT), clusterDim - 1); + __bang_max(max_box, inter_x1, NMS_SIZE); + int max_cluster = (sizeof(IN_DT) == sizeof(half)) + ? ((uint16_t *)max_box)[1] + : ((uint32_t *)max_box)[1]; + __memcpy(max_box, sram + max_cluster * REDUCE_NUM, + REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); + __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); + } + __sync_cluster(); + if (coreId == 0x80 && clusterDim > 1) { + // broadcast global max box to each cluster's sram + for (int cluster_idx = 1; cluster_idx < clusterDim; ++cluster_idx) { + __memcpy(sram, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2SRAM, + cluster_idx); + } + } + __sync_cluster(); + } + __sync_all(); + + // copy the global max box to max_box + __memcpy(max_box, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); + if (algo == 0 || offset == 0.0) { + max_area = ((float)max_box[3] - (float)max_box[1]) * + ((float)max_box[4] - (float)max_box[2]); + } else { + max_area = ((float)max_box[3] - (float)max_box[1] + offset) * + ((float)max_box[4] - (float)max_box[2] + offset); + } + global_max_index = ((uint32_t *)(max_box + 5))[0]; + if (coreId != 0x80) { + input_score_ptr[global_max_index] = 0; + } + // by now, we get: max_score|max_index|max_box|max_area + /******FIND MAX END******/ + + /******NMS STORE START******/ + // store to nram + if (float(max_box[0]) > thresh_score) { + OUT_DT *save_ptr; + int save_offset = 0; + int save_str_num = 0; + save_ptr = nram_save; + save_offset = nram_save_count; + save_str_num = nram_save_limit_count; + if (clusterId == 0 && coreId == 0) { + if (output_mode == 0) { // index1, index2, ... + save_ptr[save_offset] = ((uint32_t *)(max_box + INFO_NUM))[0]; + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + __memcpy(save_ptr + save_offset * INFO_NUM, max_box, + INFO_NUM * sizeof(IN_DT), NRAM2NRAM, + INFO_NUM * sizeof(IN_DT), INFO_NUM * sizeof(IN_DT), 0); + } else if (output_mode == 2) { // score---, x1---, y1---, x2---, y2--- + __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), + NRAM2NRAM, save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), + 4); + } + } + nram_save_count++; + output_box_num++; + } + + // store to sram/gdram + if (output_box_num != 0) { + if ((nram_save_count == nram_save_limit_count) || + (float(max_box[0]) <= thresh_score) || keep == max_output_size - 1) { + if (nram_save_count != 0) { + if (clusterId == 0 && coreId == 0) { + if (output_mode == 0) { // index1, index2, ... + pvLock(); + __memcpy(output_dram, nram_save, + nram_save_count * sizeof(uint32_t), NRAM2GDRAM); + pvUnlock(); + output_dram += nram_save_count; + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + pvLock(); + __memcpy(output_dram, nram_save, + nram_save_count * INFO_NUM * sizeof(IN_DT), NRAM2GDRAM); + pvUnlock(); + output_dram += nram_save_count * INFO_NUM; + } else if (output_mode == + 2) { // score---, x1---, y1---, x2---, y2--- + pvLock(); + __memcpy(output_dram, nram_save, nram_save_count * sizeof(IN_DT), + NRAM2GDRAM, max_output_size * sizeof(IN_DT), + nram_save_limit_count * sizeof(IN_DT), 4); + pvUnlock(); + output_dram += nram_save_count; + } + nram_save_count = 0; + } + } + } // if move data nram->sram/gdram + } // if dst + + if (float(max_box[0]) <= thresh_score) { + if (clusterId == 0 && coreId == 0) { + loop_end_flag[0] = 1; // dram + } + } + __sync_all(); + if (loop_end_flag[0] == 1) { + break; + } + /******NMS STORE END******/ + + // To solve fp16 accuracy, we convert fp16 to fp32 to calculate IoU. + for (int i = 0; i <= repeat_iou_compute; i++) { + if (i == repeat_iou_compute && remain_iou_compute == 0) { + break; + } + int seg_len = (i == repeat_iou_compute) ? remain_pad_iou_compute + : max_seg_iou_compute; + int cpy_len = + (i == repeat_iou_compute) ? remain_iou_compute : max_seg_iou_compute; + + /******NMS LOAD START******/ + __nramset((float *)score, seg_len, 0.0f); + int dt_offset = 0; + if (sizeof(IN_DT) == sizeof(float)) { + __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + dt_offset = 0; + } else if (sizeof(IN_DT) == sizeof(half)) { + __nramset(x1, seg_len, half(0)); + __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + __bang_half2float((float *)score, (half *)x1, seg_len); + dt_offset = max_seg_iou_compute; + } + + __memcpy(x1 + dt_offset, + input_x1_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), input_load_dir, + max_seg_pad * sizeof(IN_DT), input_num_boxes * sizeof(IN_DT), 3); + /******NMS LOAD END******/ + + /******NMS COMPUTE START******/ + if (sizeof(IN_DT) == sizeof(half)) { + __bang_half2float((float *)x1, (half *)x1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y1, (half *)y1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)x2, (half *)x2 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y2, (half *)y2 + max_seg_iou_compute, + seg_len); + } + // 1、 compute IOU + // get the area_I + __nramset((float *)inter_y1, seg_len, float(max_box[1])); // max_x1 + __bang_maxequal((float *)inter_x1, (float *)x1, (float *)inter_y1, + seg_len); // inter_x1 + __nramset((float *)inter_y2, seg_len, float(max_box[3])); // max_x2 + __bang_minequal((float *)inter_x2, (float *)x2, (float *)inter_y2, + seg_len); // inter_x2 + __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_x1, (float *)inter_x1, offset, seg_len); + } + __bang_active_relu((float *)inter_x1, (float *)inter_x1, + seg_len); // inter_w + __nramset((float *)inter_x2, seg_len, float(max_box[2])); // max_y1 + __bang_maxequal((float *)inter_y1, (float *)y1, (float *)inter_x2, + seg_len); // inter_y1 + __nramset((float *)inter_x2, seg_len, float(max_box[4])); // max_y2 + __bang_minequal((float *)inter_y2, (float *)y2, (float *)inter_x2, + seg_len); // inter_y2 + __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + } + __bang_active_relu((float *)inter_y1, (float *)inter_y1, + seg_len); // inter_h + __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1, + seg_len); // area_I + // get the area of input_box: area = (x2 - x1) * (y2 - y1); + __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len); + __bang_sub((float *)inter_y2, (float *)y2, (float *)y1, seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + __bang_add_const((float *)inter_y2, (float *)inter_y2, offset, seg_len); + } + __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2, + seg_len); // area + // get the area_U: area + max_area - area_I + __bang_add_const((float *)inter_x2, (float *)inter_x2, float(max_area), + seg_len); + __bang_sub((float *)inter_x2, (float *)inter_x2, (float *)inter_x1, + seg_len); // area_U + // 2、 select the box + // if IOU greater than thres, set the score to zero, abort it: area_U > + // area_I * (1 / thresh)? + if (thresh_iou > 0.0) { + __bang_mul_const((float *)inter_x1, (float *)inter_x1, div_thresh_iou, + seg_len); + } else { + __bang_mul_const((float *)inter_x2, (float *)inter_x2, thresh_iou, + seg_len); + } + __bang_ge((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len); + /******NMS COMPUTE END******/ + + if (sizeof(IN_DT) == 2) { + __bang_float2half_rd((half *)score, (float *)score, seg_len); + } + pvLock(); + __memcpy(input_score_ptr + input_offset + i * max_seg_iou_compute, score, + cpy_len * sizeof(IN_DT), input_store_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + pvUnlock(); + } // for repeat + } // for max_output_size +} + +__mlu_global__ void MLUUionXKernelNMS( + const void *input_boxes, const void *input_confidence, + const int input_num_boxes, const int input_layout, const int input_stride, + const int max_output_size, const float iou_threshold, + const float confidence_threshold, const float offset, + const cnrtDataType_t data_type_input, const int output_mode, const int algo, + void *workspace, void *result_num, void *output) { + int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2; + int32_t *loop_end_flag = + (int32_t *)((char *)workspace + + INFO_NUM * input_num_boxes * input_dwidth); + int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth; + int availbale_sram_size = SIZE_SRAM_BUF - reduce_sram_size; + + int cluster_score_size = input_num_boxes * input_dwidth; + int cluster_boxes_size = input_num_boxes * 4 * input_dwidth; + char *sram_score = (char *)sram_buffer + reduce_sram_size; + char *sram_boxes = + (char *)sram_buffer + reduce_sram_size + cluster_score_size; + Addr input_ram = GDRAM; + if ((cluster_score_size + cluster_boxes_size) < availbale_sram_size) { + input_ram = SRAM; + __memcpy(sram_score, input_confidence, cluster_score_size, GDRAM2SRAM); + __memcpy(sram_boxes, input_boxes, cluster_boxes_size, GDRAM2SRAM); + } else { + __memcpy(workspace, input_confidence, cluster_score_size, GDRAM2GDRAM); + } + __sync_cluster(); + uint32_t output_box_num = 0; + if (output_mode == 0) { + uint32_t *output_dram = (uint32_t *)output; + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *score_data; + half *boxes_data; + score_data = + (input_ram == SRAM) ? (half *)sram_score : (half *)workspace; + boxes_data = + (input_ram == SRAM) ? (half *)sram_boxes : (half *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + case CNRT_FLOAT32: { + float *score_data; + float *boxes_data; + score_data = + (input_ram == SRAM) ? (float *)sram_score : (float *)workspace; + boxes_data = + (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + } + } else { + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *output_dram = (half *)output; + half *score_data; + half *boxes_data; + score_data = + (input_ram == SRAM) ? (half *)sram_score : (half *)workspace; + boxes_data = + (input_ram == SRAM) ? (half *)sram_boxes : (half *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + case CNRT_FLOAT32: { + float *output_dram = (float *)output; + float *score_data; + float *boxes_data; + score_data = + (input_ram == SRAM) ? (float *)sram_score : (float *)workspace; + boxes_data = + (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + } + } +} + void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t data_type_input, const void *boxes_ptr, const void *scores_ptr, const int input_num_boxes, const int input_stride, const int max_output_boxes, const float iou_threshold, const float offset, void *workspace_ptr, void *output_size_ptr, void *output_ptr) { - MLUKernelNMS<<>>( - boxes_ptr, scores_ptr, input_num_boxes, input_stride, max_output_boxes, - iou_threshold, /*confidence_threshold=*/0.0, /*output_mode=*/0, - /*input_layout=*/0, workspace_ptr, output_size_ptr, output_ptr, - data_type_input, offset, /*algo=*/1); + switch (k_type) { + default: { return; } + case CNRT_FUNC_TYPE_BLOCK: + case CNRT_FUNC_TYPE_UNION1: { + MLUUnion1KernelNMS<<>>( + boxes_ptr, scores_ptr, input_num_boxes, input_stride, + max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, + /*output_mode=*/0, + /*input_layout=*/1, workspace_ptr, output_size_ptr, output_ptr, + data_type_input, offset, /*algo=*/1); + }; break; + case CNRT_FUNC_TYPE_UNION2: + case CNRT_FUNC_TYPE_UNION4: + case CNRT_FUNC_TYPE_UNION8: + case CNRT_FUNC_TYPE_UNION16: { + MLUUionXKernelNMS<<>>( + boxes_ptr, scores_ptr, input_num_boxes, /*input_layout=*/1, + input_stride, max_output_boxes, iou_threshold, + /*confidence_threshold=*/0.0, offset, data_type_input, + /*output_mode=*/0, /*algo=*/1, workspace_ptr, output_size_ptr, + output_ptr); + }; break; + } } diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu index e11aa4c575..55df914ab0 100644 --- a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu @@ -14,8 +14,7 @@ __nram__ char buffer[MAX_NRAM_SIZE]; #define ALIGN_SIZE 64 -#define MAX_ELEMENTS_FLOAT (50 * 1024) -#define MAX_ELEMENTS_HALF (100 * 1024) +#define BUFFER_SIZE (MAX_NRAM_SIZE * 480 / 512) #define ROI_OFFSET 5 #define SAMPLING_NUM 4 @@ -24,59 +23,101 @@ __nram__ char buffer[MAX_NRAM_SIZE]; namespace forward { template -__mlu_func__ void bilinearInterpolate(int input_height, int input_width, - float y, float x, T *w1, T *w2, T *w3, - T *w4, int *x_low, int *x_high, - int *y_low, int *y_high, int *empty, - T zero_sign) { - // deal with cases that inverse elements are of feature map boundary - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - *empty = 1; - return; - } +__mlu_func__ void bilinearInterpolate( + T *tmp_sum, T *nram_in, T *offset_bottom_data, const int roi_bin_grid_h, + const int roi_bin_grid_w, const T bin_size_h, const T bin_size_w, + const int input_height, const int input_width, const int channels, + const int channel_align, const int cyc_channel, T y_pre, T x_pre, + T zero_sign_tmp, bool is_normal_c, int index) { + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + T y = (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)) <= 0.0 + ? 0.0 + : (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)); + int y_low = int(y); + int y_high; + if (y_low >= input_height - 1) { + y_high = y_low = input_height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + T ly = y - y_low; + T hy = 1.0 - ly; + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + T x = (x_pre + ((ix + 0.5) * bin_size_w) / (T)(roi_bin_grid_w)) <= 0.0 + ? 0.0 + : (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)); + T zero_sign = + (T)(x >= -1.0 && x <= input_width && y >= -1.0 && y <= input_height) * + zero_sign_tmp; + int x_low = int(x); + int x_high; + if (x_low >= input_width - 1) { + x_high = x_low = input_width - 1; + x = T(x_low); + } else { + x_high = x_low + 1; + } + T lx = x - x_low; + T hx = 1.0 - lx; - if (y <= 0) y = 0; - if (x <= 0) x = 0; + T w1 = hy * hx * zero_sign; + T w2 = hy * lx * zero_sign; + T w3 = ly * hx * zero_sign; + T w4 = ly * lx * zero_sign; - *y_low = int(y); - *x_low = int(x); + // load + int cpy_len = (x_high - x_low) * channels; + int temp_size = cyc_channel < (channels - index * cyc_channel) + ? cyc_channel + : channels - index * cyc_channel; + int cpy_size = is_normal_c ? channels * sizeof(T) : temp_size * sizeof(T); - if (*y_low >= input_height - 1) { - *y_high = *y_low = input_height - 1; - y = (T)(*y_low); - } else { - *y_high = *y_low + 1; - } + int32_t offset1 = (y_low * input_width + x_low) * channels; + int32_t offset2 = (y_high * input_width + x_low) * channels; - if (*x_low >= input_width - 1) { - *x_high = *x_low = input_width - 1; - x = (T)(*x_low); - } else { - *x_high = *x_low + 1; - } + T *tmp1 = is_normal_c + ? offset_bottom_data + offset1 + : offset_bottom_data + offset1 + cyc_channel * index; + T *tmp2 = is_normal_c + ? offset_bottom_data + offset2 + : offset_bottom_data + offset2 + cyc_channel * index; - T ly = y - *y_low; - T lx = x - *x_low; - T hy = 1.0 - ly; - T hx = 1.0 - lx; + T *tmp_cyc1 = nram_in; + T *tmp_cyc2 = nram_in + cyc_channel; + T *tmp_cyc3 = nram_in + cyc_channel * 2; + T *tmp_cyc4 = nram_in + cyc_channel * 3; - *w1 = hy * hx * zero_sign; - *w2 = hy * lx * zero_sign; - *w3 = ly * hx * zero_sign; - *w4 = ly * lx * zero_sign; - - return; + __asm__ volatile("sync;"); + if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { + __nramset(nram_in, channel_align, T(0)); + } else { + __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); + __asm__ volatile("sync;"); + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); + __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, 1, + SAMPLING_NUM, 1, 1); + } + __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); + } + } } template -__mlu_func__ void roialignForwardKernel( - T *input, T *rois, T *output, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_elements) { +__mlu_func__ void roialignForwardNpartKernel( + T *input, T *rois, T *output, T *nram_buffer, const bool aligned, + const int channels, const int pooled_height, const int pooled_width, + const int input_height, const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, const int max_elements) { /* * NRAM partition - * |----------------------NRAM------ -----------------| + * |----------------------NRAM------------------------| * | | * | output | * |--------------------------------------------------| @@ -92,8 +133,6 @@ __mlu_func__ void roialignForwardKernel( */ int channel_align = PAD_UP(channels, ALIGN_SIZE); - int height = 0; - int width = 0; int samp_channel_align = channel_align * SAMPLING_NUM; int samp_channel = channels * SAMPLING_NUM; @@ -103,7 +142,7 @@ __mlu_func__ void roialignForwardKernel( int offset_length; int task_length; - // the length dealt by every core and the offset of taskid + // the length dealt by every core and the offset of taskId if (taskId < rem_num) { task_length = inter_num + 1; offset_length = taskId * (inter_num + 1); @@ -112,41 +151,393 @@ __mlu_func__ void roialignForwardKernel( offset_length = rem_num * (inter_num + 1) + (taskId - rem_num) * inter_num; } - int max_size = max_elements >> 1; - T *nram_out = (T *)buffer; - T *nram_in = nram_out + max_size; - T *nram_rois = nram_in + max_elements; + int max_size = max_elements; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size * 2; int pooled_size = pooled_height * pooled_width; - // output and roi data ptr T *top_data = output + offset_length * pooled_size * channels; T *task_rois = rois + offset_length * ROI_OFFSET; for (int roi_id = 0; roi_id < task_length; roi_id++) { // For each roi, find the corresponding feature map which it belongs to, // and compute the scaling_factor to map it to that feature map. - height = input_height; - width = input_width; T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = task_rois + roi_id * ROI_OFFSET; - __bang_write_zero(nram_rois, ALIGN_SIZE); - __memcpy((void *)nram_rois, (void *)roi_id_tmp, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - int batch_id = nram_rois[0]; - T roi_xmin = nram_rois[1]; - T roi_ymin = nram_rois[2]; - T roi_xmax = nram_rois[3]; - T roi_ymax = nram_rois[4]; + int batch_id = roi_id_tmp[0]; + T roi_xmin = roi_id_tmp[1]; + T roi_ymin = roi_id_tmp[2]; + T roi_xmax = roi_id_tmp[3]; + T roi_ymax = roi_id_tmp[4]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + roi_width = roi_width > 1.0 ? roi_width : 1.0; + roi_height = roi_height > 1.0 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); + + for (int ph = 0; ph < pooled_height; ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, channel_align, channel_align, + y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = + samp_channel / (max_elements * SAMPLING_NUM) + + (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); + int cyc_channel = max_elements; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph + } // loop for num_roi +} + +template +__mlu_func__ void roialignForwardHpartKernel( + T *input, T *rois, T *output, T *nram_buffer, const bool aligned, + const int channels, const int pooled_height, const int pooled_width, + const int input_height, const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, const int max_elements) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int taskdim_cyc = taskDim / num_rois > 1 ? taskDim / num_rois : 1; + int roi_id = taskId / taskdim_cyc; + if (taskId >= taskdim_cyc * num_rois) { + return; + } + + // multi-core params + int inter_num = pooled_height / taskdim_cyc; + int rem_num = pooled_height % taskdim_cyc; + int offset_length; + int task_length; + + if ((taskId % taskdim_cyc) < rem_num) { + task_length = inter_num + 1; + offset_length = (taskId % taskdim_cyc) * (inter_num + 1); + } else { + task_length = inter_num; + offset_length = rem_num * (inter_num + 1) + + ((taskId % taskdim_cyc) - rem_num) * inter_num; + } + + int max_size = max_elements * 2; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + + int pooled_size = pooled_height * pooled_width; + T *top_data = + output + (roi_id * pooled_size + offset_length * pooled_width) * channels; + T offset = aligned ? (T)0.5 : (T)0; + T *roi_id_tmp = rois + roi_id * ROI_OFFSET; + + int batch_id = roi_id_tmp[0]; + T roi_xmin = roi_id_tmp[1]; + T roi_ymin = roi_id_tmp[2]; + T roi_xmax = roi_id_tmp[3]; + T roi_ymax = roi_id_tmp[4]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); + + for (int ph = offset_length; ph < (offset_length + task_length); ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, + bin_size_w, input_height, input_width, channels, + channel_align, channel_align, y_pre, x_pre, + zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / (max_elements * SAMPLING_NUM) + + (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); + int cyc_channel = max_elements; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph +} + +__mlu_global__ void MLUUnion1KernelRoialign( + const void *input, const void *rois, const int channels, const bool aligned, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const cnrtDataType_t data_type, void *output) { + size_t data_type_size = + (data_type == CNRT_FLOAT32) ? sizeof(float) : sizeof(half); + int max_elements = PAD_DOWN( + (BUFFER_SIZE / (int)data_type_size) / (ROI_OFFSET + 1), ALIGN_SIZE); + + if (taskDim < num_rois || (num_rois * pooled_height < taskDim)) { + switch (data_type) { + case CNRT_FLOAT16: { + half *nram_buffer = (half *)buffer; + roialignForwardNpartKernel( + (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_elements); + }; break; + case CNRT_FLOAT32: { + float *nram_buffer = (float *)buffer; + roialignForwardNpartKernel( + (float *)input, (float *)rois, (float *)output, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + default: + break; + } + } else { + switch (data_type) { + case CNRT_FLOAT16: { + half *nram_buffer = (half *)buffer; + roialignForwardHpartKernel( + (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_elements); + }; break; + case CNRT_FLOAT32: { + float *nram_buffer = (float *)buffer; + roialignForwardHpartKernel( + (float *)input, (float *)rois, (float *)output, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + default: + break; + } + } + return; +} + +template +__mlu_func__ void buSelection(T *rois_count, T *nram_temp, const int num_rois) { + for (int i = 0; i < num_rois; ++i) { + for (int j = 1; j < num_rois; ++j) { + if (rois_count[(j - 1) * 2] < rois_count[j * 2]) { + nram_temp[0] = rois_count[(j - 1) * 2]; + rois_count[(j - 1) * 2] = rois_count[j * 2]; + rois_count[j * 2] = nram_temp[0]; + nram_temp[1] = rois_count[(j - 1) * 2 + 1]; + rois_count[(j - 1) * 2 + 1] = rois_count[j * 2 + 1]; + rois_count[j * 2 + 1] = nram_temp[1]; + } + } + } +} + +template +__mlu_func__ void getPatitionList(T *h_nram, T *n_nram, T *roi_count, + int pooled_height, int num_rois, T sum, + int split_num, int &h_flag, int &n_flag) { + T avg_sum = sum / split_num; + T *h_nram_temp = h_nram; + T *n_nram_temp = n_nram; + + int n_index = 0; + T n_sum = 0; + h_flag = 0; + n_flag = 0; + int list_align = PAD_UP(ALIGN_SIZE * 5, ALIGN_SIZE); + __bang_write_zero(h_nram, list_align); + for (int i = 0; i < num_rois; i++) { + if (roi_count[2 * i] >= avg_sum) { + int h_num = std::ceil(roi_count[2 * i] / avg_sum); + int h_split = pooled_height / h_num; + int h_rem = pooled_height % h_num; + T h_sum = 0.0; + + for (int j = 0; j < h_num; j++) { + h_nram_temp[0] = i; + h_nram_temp[1] = h_sum; + h_nram_temp[2] = (j < h_rem) ? (h_split + 1) : h_split; + h_sum += h_nram_temp[2]; + h_nram_temp += 3; + n_nram_temp += 2; + h_flag++; + } + } else { + if (roi_count[2 * i] + n_sum > avg_sum) { + n_nram_temp[0] = i - n_index; + n_nram_temp[1] = i - 1; + n_sum = 0.0; + n_index = 0; + n_nram_temp += 2; + i--; + n_flag++; + } else { + n_index++; + n_sum += roi_count[2 * i]; + } + } + } + if (n_flag == 0 && n_index != 0) { + n_flag = 1; + n_nram[(h_flag + n_flag - 1) * 2] = num_rois - 1; + } + + n_nram[(h_flag + n_flag) * 2 - 1] = num_rois - 1; + + if (h_flag + n_flag > taskDim) { + getPatitionList(h_nram, n_nram, roi_count, pooled_height, num_rois, sum, + split_num - 1, h_flag, n_flag); + } + return; +} + +template +__mlu_func__ void mergeAndSplitQuantity( + T *rois, T *rois_sort, T *split_list, T *roi_count, T *nram_rois, + const bool aligned, const int pooled_height, const int pooled_width, + const int sampling_ratio, const float spatial_scale, const int num_rois, + int &h_split_num, int &n_split_num) { + /* take the coordinates out of ROIS and actually calculate the actual + * calculation size. The sorted calculation scale is partition, large scale + * is split H, small is N. + */ + T *h_tem = split_list; + T *n_tem = split_list + 3 * ALIGN_SIZE; + int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 1), ALIGN_SIZE); + int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); + __bang_write_zero(nram_rois, num_rois_align); + T sum = 0.0; + int temp_offset = 0; + __memcpy((void *)(nram_rois + 1), (void *)rois, ROI_OFFSET * sizeof(T), + GDRAM2NRAM, (ROI_OFFSET + 1) * sizeof(T), ROI_OFFSET * sizeof(T), + (num_rois - 1)); + T *nram_temp = roi_count + count_align; + for (int roi_id = 0; roi_id < num_rois; roi_id++) { + T offset = aligned ? (T)0.5 : (T)0; - roi_xmin = roi_xmin * spatial_scale - offset; - roi_ymin = roi_ymin * spatial_scale - offset; - roi_xmax = roi_xmax * spatial_scale - offset; - roi_ymax = roi_ymax * spatial_scale - offset; + T roi_xmin = nram_rois[temp_offset + 2]; + T roi_ymin = nram_rois[temp_offset + 3]; + T roi_xmax = nram_rois[temp_offset + 4]; + T roi_ymax = nram_rois[temp_offset + 5]; - float roi_width = roi_xmax - roi_xmin; - float roi_height = roi_ymax - roi_ymin; + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; if (!aligned) { // Force malformed ROIs to be 1x1 @@ -154,173 +545,158 @@ __mlu_func__ void roialignForwardKernel( roi_height = roi_height > 1 ? roi_height : 1.0; } - float bin_size_h = (float)roi_height / pooled_height; - float bin_size_w = (float)roi_width / pooled_width; + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + sum += count; + *(roi_count + 2 * roi_id) = count; + *(roi_count + 2 * roi_id + 1) = roi_id; + + *(nram_rois + roi_id * (ROI_OFFSET + 1)) = count; + temp_offset += (ROI_OFFSET + 1); + } + + buSelection(roi_count, nram_temp, num_rois); + + temp_offset = 0; + for (int i = 0; i < num_rois; i++) { + for (int j = 0; j < num_rois; j++) { + if (roi_count[2 * i] == nram_rois[j * (ROI_OFFSET + 1)]) { + rois_sort[temp_offset] = nram_rois[j * (ROI_OFFSET + 1)]; + rois_sort[temp_offset + 1] = nram_rois[j * (ROI_OFFSET + 1) + 1]; + rois_sort[temp_offset + 2] = nram_rois[j * (ROI_OFFSET + 1) + 2]; + rois_sort[temp_offset + 3] = nram_rois[j * (ROI_OFFSET + 1) + 3]; + rois_sort[temp_offset + 4] = nram_rois[j * (ROI_OFFSET + 1) + 4]; + rois_sort[temp_offset + 5] = nram_rois[j * (ROI_OFFSET + 1) + 5]; + nram_rois[j * (ROI_OFFSET + 1)] = -1.0; + break; + } + } + temp_offset += (ROI_OFFSET + 1); + } + getPatitionList(h_tem, n_tem, roi_count, pooled_height, num_rois, sum, + taskDim, h_split_num, n_split_num); +} + +template +__mlu_func__ void roialignForwardNpartKernelForBinPart( + T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, + T *nram_buffer, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const int max_size) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int max_elements = max_size * SAMPLING_NUM; + int offset_length; + int task_length; + + T *n_split_nram = split_list + 3 * ALIGN_SIZE + 2 * taskId; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + T *task_rois = rois_sort + (int)n_split_nram[0] * (ROI_OFFSET + 1); + + offset_length = (int)n_split_nram[0]; + task_length = n_split_nram[1] - n_split_nram[0] + 1; + int pooled_size = pooled_height * pooled_width; + + for (int roi_id = offset_length; roi_id < offset_length + task_length; + roi_id++) { + // For each roi, find the corresponding feature map which it belongs to, + // and compute the scaling_factor to map it to that feature map. + T offset = aligned ? (T)0.5 : (T)0; + int rea_out_id = rois_count[roi_id * 2 + 1]; + T *top_data = output + rea_out_id * pooled_size * channels; + T *nram_rois = task_rois + (roi_id - offset_length) * (ROI_OFFSET + 1); + + int batch_id = nram_rois[1]; + T roi_xmin = nram_rois[2]; + T roi_ymin = nram_rois[3]; + T roi_xmax = nram_rois[4]; + T roi_ymax = nram_rois[5]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1.0 ? roi_width : 1.0; + roi_height = roi_height > 1.0 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; - // input data ptr - T *offset_bottom_data = input + batch_id * channels * width * height; T *tmp_sum = nram_out; + __bang_write_zero(nram_in, max_elements); __bang_write_zero(nram_out, max_size); // We use roi_bin_grid to sample the grid, and perform average pooling // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = - (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_h); - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_w); - float count = roi_bin_grid_h * roi_bin_grid_w; - float zero_sign_tmp = 1.0f / count; + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < max_elements; for (int ph = 0; ph < pooled_height; ph++) { - float y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid for (int pw = 0; pw < pooled_width; pw++) { - float x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid // Bilinear interpolatation - if (samp_channel_align < max_elements) { - // One aligned channel data can be computed at one time - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - float y = - (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 - ? 0 - : (y_pre + - ((iy + 0.5) * bin_size_h) / - (roi_bin_grid_h)); // center_point y - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - float x = - (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 - ? 0 - : (x_pre + - ((ix + 0.5) * bin_size_w) / - (roi_bin_grid_w)); // center_point x - T zero_sign = - (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * - zero_sign_tmp; - - int empty = 0; - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, - &w3, &w4, &x_low, &x_high, &y_low, &y_high, - &empty, zero_sign); - - // load - int cpy_len = (x_high - x_low) * channels; - int cpy_size = channels * sizeof(T); - - int offset1 = (y_low * width + x_low) * channels; - int offset2 = (y_high * width + x_low) * channels; - - T *tmp1 = offset_bottom_data + offset1; - T *tmp2 = offset_bottom_data + offset2; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + channel_align; - T *tmp_cyc3 = nram_in + channel_align * 2; - T *tmp_cyc4 = nram_in + channel_align * 3; - __asm__ volatile("sync;"); - if (empty == 1) { - __nramset(nram_in, channel_align, T(0)); - } else { - // load gdram to nram - __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); - __asm__ volatile("sync;"); - // roialign_forward compute - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); - __bang_sumpool(nram_in, nram_in, channel_align, 1, SAMPLING_NUM, - 1, SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); - } - } + if (is_normal_c) { + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, channel_align, channel_align, + y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); } else { // One aligned channel data cannot be computed at one time int cyc_num = samp_channel / max_elements + (int)(samp_channel % max_elements != 0); int cyc_channel = max_elements / SAMPLING_NUM; for (int i = 0; i < cyc_num; ++i) { - int real_channel = - (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; - int align_channel = - (i == cyc_num - 1) - ? PAD_UP((channel_align - i * cyc_channel), ALIGN_SIZE) - : cyc_channel; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - float y = - (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 - ? 0 - : (y_pre + - ((iy + 0.5) * bin_size_h) / - (roi_bin_grid_h)); // center_point y - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - float x = - (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 - ? 0 - : (x_pre + - ((ix + 0.5) * bin_size_w) / - (roi_bin_grid_w)); // center_point x - - T zero_sign = - (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * - zero_sign_tmp; - - int empty = 0; - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, - &w3, &w4, &x_low, &x_high, &y_low, &y_high, - &empty, zero_sign); - - // load - int cpy_len = (x_high - x_low) * channels; - - int offset1 = (y_low * width + x_low) * channels; - int offset2 = (y_high * width + x_low) * channels; - - T *tmp1 = offset_bottom_data + offset1 + cyc_channel * i; - T *tmp2 = offset_bottom_data + offset2 + cyc_channel * i; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + cyc_channel; - T *tmp_cyc3 = nram_in + cyc_channel * 2; - T *tmp_cyc4 = nram_in + cyc_channel * 3; - __asm__ volatile("sync;"); - if (empty == 1) { // exits abnormal values - __nramset(nram_in, align_channel, T(0)); - } else { - __memcpy_async(tmp_cyc1, tmp1, align_channel * sizeof(T), - GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, - align_channel * sizeof(T), GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, align_channel * sizeof(T), - GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, - align_channel * sizeof(T), GDRAM2NRAM); - __asm__ volatile("sync;"); - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, align_channel); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, align_channel); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, align_channel); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, align_channel); - __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, - 1, SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, align_channel); - } - } + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + __memcpy(top_data + cyc_channel * i, tmp_sum, real_channel * sizeof(T), NRAM2GDRAM); __bang_write_zero(nram_out, max_size); } } // copy output data to ddr when channel num is not aligned with 64 - if (samp_channel_align < max_elements) { + if (is_normal_c) { __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); __bang_write_zero(nram_out, max_size); } @@ -330,25 +706,208 @@ __mlu_func__ void roialignForwardKernel( } // loop for num_roi } -__mlu_global__ void MLUUnion1KernelRoialign( +template +__mlu_func__ void roialignForwardHpartKernelForBinPart( + T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, + T *nram_buffer, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const int max_size) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int max_elements = max_size * SAMPLING_NUM; + + T *h_split_nram = split_list; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + T *nram_rois = rois_sort + (int)h_split_nram[taskId * 3] * (ROI_OFFSET + 1); + + int offset_length = (int)h_split_nram[taskId * 3 + 1]; + int task_length = (int)h_split_nram[taskId * 3 + 2]; + int rea_out_id = (int)h_split_nram[taskId * 3]; + + rea_out_id = rois_count[rea_out_id * 2 + 1]; + int pooled_size = pooled_height * pooled_width; + T *top_data = + output + + (rea_out_id * pooled_size + offset_length * pooled_width) * channels; + + T offset = aligned ? (T)0.5 : (T)0; + + int batch_id = nram_rois[1]; + T roi_xmin = nram_rois[2]; + T roi_ymin = nram_rois[3]; + T roi_xmax = nram_rois[4]; + T roi_ymax = nram_rois[5]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_in, max_elements); + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < max_elements; + + for (int ph = offset_length; ph < (offset_length + task_length); ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, + bin_size_w, input_height, input_width, channels, + channel_align, channel_align, y_pre, x_pre, + zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / max_elements + + (int)(samp_channel % max_elements != 0); + int cyc_channel = max_elements / SAMPLING_NUM; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph +} + +__mlu_global__ void MLUUnion1KernelBinPartRoialign( const void *input, const void *rois, const int channels, const bool aligned, const int pooled_height, const int pooled_width, const int input_height, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, const cnrtDataType_t data_type, void *output) { - int max_elements = - (data_type == CNRT_FLOAT32) ? MAX_ELEMENTS_FLOAT : MAX_ELEMENTS_HALF; + int h_split_num = 0; + int n_split_num = 0; + int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 4), ALIGN_SIZE); + int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); + int list_align = ALIGN_SIZE * 5; + int sum_size = num_rois_align + count_align + list_align; + + if (coreId == 0x80) { + return; + } + switch (data_type) { case CNRT_FLOAT16: { - roialignForwardKernel((half *)input, (half *)rois, (half *)output, - aligned, channels, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); + int max_channel = + PAD_DOWN((BUFFER_SIZE / sizeof(half) - sum_size) / (ROI_OFFSET + 1), + ALIGN_SIZE); + half *rois_sort = (half *)buffer; + __bang_write_zero(rois_sort, sum_size); + half *rois_count = (half *)(rois_sort + num_rois_align); + half *split_list = (half *)(rois_count + count_align); + half *nram_rois = (half *)(split_list + list_align); + mergeAndSplitQuantity((half *)rois, (half *)rois_sort, (half *)split_list, + (half *)rois_count, (half *)nram_rois, aligned, + pooled_height, pooled_width, sampling_ratio, + spatial_scale, num_rois, h_split_num, n_split_num); + half *nram_buffer = (half *)nram_rois; + __bang_write_zero(nram_rois, num_rois_align); + + if (taskId < h_split_num) { + roialignForwardHpartKernelForBinPart( + (half *)input, (half *)rois, (half *)output, (half *)rois_sort, + (half *)split_list, (half *)rois_count, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_channel); + } else { + if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { + roialignForwardNpartKernelForBinPart( + (half *)input, (half *)rois, (half *)output, (half *)rois_sort, + (half *)split_list, (half *)rois_count, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, + max_channel); + } else { + return; + } + } }; break; case CNRT_FLOAT32: { - roialignForwardKernel((float *)input, (float *)rois, (float *)output, - aligned, channels, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); + int max_channel = + PAD_DOWN((BUFFER_SIZE / sizeof(float) - sum_size) / (ROI_OFFSET + 1), + ALIGN_SIZE); + float *rois_sort = (float *)buffer; + __bang_write_zero(rois_sort, sum_size); + float *rois_count = (float *)(rois_sort + num_rois_align); + float *split_list = (float *)(rois_count + count_align); + float *nram_rois = (float *)(split_list + list_align); + mergeAndSplitQuantity((float *)rois, (float *)rois_sort, + (float *)split_list, (float *)rois_count, + (float *)nram_rois, aligned, pooled_height, + pooled_width, sampling_ratio, spatial_scale, + num_rois, h_split_num, n_split_num); + float *nram_buffer = (float *)nram_rois; + __bang_write_zero(nram_rois, num_rois_align); + + if (taskId < h_split_num) { + roialignForwardHpartKernelForBinPart( + (float *)input, (float *)rois, (float *)output, (float *)rois_sort, + (float *)split_list, (float *)rois_count, (float *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_channel); + } else { + if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { + roialignForwardNpartKernelForBinPart( + (float *)input, (float *)rois, (float *)output, + (float *)rois_sort, (float *)split_list, (float *)rois_count, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_channel); + } else { + return; + } + } }; break; default: break; @@ -400,14 +959,17 @@ __mlu_func__ void unionRoiAlignBp( const int wi, const int c, const int no, const int ho, const int wo, const float spatial_scale, const int sampling_ratio, const bool aligned) { int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); - int deal_this_core = - boxes_num / taskDim + (int)(taskId < boxes_num % taskDim); + int deal_all = boxes_num * hi * wi; + int deal_this_core = deal_all / taskDim + (int)(taskId < deal_all % taskDim); for (int i = 0; i < deal_this_core; ++i) { - int box_id = i * taskDim + taskId; - T *box = boxes + box_id * DIM_BOX; - T *grads_offset = grads + box_id * hi * wi * c; + int bhw_id = i * taskDim + taskId; + int box_id = bhw_id / (hi * wi); + int ih = (bhw_id / wi) % hi; + int iw = bhw_id % wi; + T *box = boxes + box_id * 5; int image_id = (int)box[0]; T *image_offset = grads_image + image_id * ho * wo * c; + T *grads_ = grads + box_id * hi * wi * c + ih * wi * c + iw * c; float offset = aligned ? 0.5 : 0.0; float x1 = box[1] * spatial_scale - offset; @@ -427,108 +989,113 @@ __mlu_func__ void unionRoiAlignBp( (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi); int roi_grid_w = (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / wi); - const int count = roi_grid_h * roi_grid_w; - if (c_align * sizeof(T) * BLOCK_INPUT_OUTPUT <= MAX_NRAM_SIZE) { - for (int ih = 0; ih < hi; ++ih) { - for (int iw = 0; iw < wi; ++iw) { - T *grads_ = grads_offset + ih * wi * c + iw * c; - for (int iy = 0; iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, - &x_low, &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w1 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w2 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_high * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w3 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w4 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_high * c, - (T *)buffer + c_align, c); - } // x_low && y_low - } // ix - } // iy - } // iw - } // ih + const T count = roi_grid_h * roi_grid_w; + if (c_align * sizeof(T) * 2 <= MAX_NRAM_SIZE) { + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, + &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w1, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w2, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_high * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w3, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w4, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_high * c, + (T *)buffer + c_align, c); + } // x_low && y_low + } // ix + } // iy } else { - for (int ih = 0; ih < hi; ++ih) { - for (int iw = 0; iw < wi; ++iw) { - T *grads_ = grads_offset + ih * wi * c + iw * c; - for (int iy = 0; iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, - &x_low, &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - int deal_once = PAD_DOWN(MAX_NRAM_SIZE / BLOCK_INPUT_OUTPUT, - NFU_ALIGN_SIZE) / - sizeof(T); - int c_repeat = c / deal_once + (int)(c % deal_once != 0); - for (int i = 0; i < c_repeat; ++i) { - int deal_c = deal_once; - int align_c = deal_once; - if (i == c_repeat - 1) { - deal_c = c - i * deal_once; - align_c = c_align - i * deal_once; - } - __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), - GDRAM2NRAM); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w1 / count), align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_low * wo * c + x_low * c + i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w2 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_low * wo * c + x_high * c + - i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w3 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_high * wo * c + x_low * c + - i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w4 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_high * wo * c + - x_high * c + i * deal_once, - (T *)buffer + align_c, deal_c); - } // for c_repeat - } // x_low >= 0 && y_low >= 0 - } // ix - } // iy - } // iw - } // ih - } // if c - } // i + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, + &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + int deal_once = + PAD_DOWN(MAX_NRAM_SIZE / 2, NFU_ALIGN_SIZE) / sizeof(T); + int c_repeat = c / deal_once + (int)(c % deal_once != 0); + for (int i = 0; i < c_repeat; ++i) { + int deal_c = deal_once; + int align_c = deal_once; + if (i == c_repeat - 1) { + deal_c = c - i * deal_once; + align_c = c_align - i * deal_once; + } + __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), + GDRAM2NRAM); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w1, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_low * wo * c + x_low * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w2, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_low * wo * c + x_high * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w3, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_high * wo * c + x_low * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w4, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_high * wo * c + x_high * c + i * deal_once, + (T *)buffer + align_c, deal_c); + } // for c_repeat + } // x_low >= 0 && y_low >= 0 + } // ix + } // iy + } // if c + } // i } __mlu_global__ void MLUUnion1KernelRoiAlignBackward( @@ -564,9 +1131,21 @@ void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, void *output) { - forward::MLUUnion1KernelRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); + // set thresholds for degradation caused by sorting + const int sort_border = 100; // threshold of num_rois + const int sort_cluster_num = 16; // threshold of cluster + + if (num_rois > sort_border || k_dim.y < sort_cluster_num) { + forward::MLUUnion1KernelRoialign<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, spatial_scale, num_rois, + d_type, output); + } else { + forward::MLUUnion1KernelBinPartRoialign<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, spatial_scale, num_rois, + d_type, output); + } } void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, diff --git a/mmcv/ops/csrc/pytorch/focal_loss.cpp b/mmcv/ops/csrc/pytorch/focal_loss.cpp index a0d878ff36..ea82497176 100644 --- a/mmcv/ops/csrc/pytorch/focal_loss.cpp +++ b/mmcv/ops/csrc/pytorch/focal_loss.cpp @@ -96,8 +96,10 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(target); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(target); + CHECK_MLU_INPUT(weight); + CHECK_MLU_INPUT(output); sigmoid_focal_loss_forward_mlu(input, target, weight, output, gamma, alpha); #endif } else { @@ -121,10 +123,10 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(target); - CHECK_MLU(weight); - CHECK_MLU(grad_input); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(target); + CHECK_MLU_INPUT(weight); + CHECK_MLU_INPUT(grad_input); sigmoid_focal_loss_backward_mlu(input, target, weight, grad_input, gamma, alpha); diff --git a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp index 044e8dd011..b003e51505 100644 --- a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp @@ -37,24 +37,38 @@ static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, auto N = input.size(0); auto C = input.size(1); - auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - auto c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE); + const size_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + const size_t c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE); const int split_target_num = 2; const int split_pipeline_num = 6; - auto scalar_size = NFU_ALIGN_SIZE; - auto weight_size = c_align_size; + const int has_weight = weight.data_ptr() != nullptr; const int target_data_width = target.scalar_type() == at::kLong ? target.itemsize() / 2 : target.itemsize(); - - // n_seg * c_align_size * split_pipeline_num + - // n_seg * target.itemsize() * split_target_num + - // weight_size + scalar_size <= nram_size - auto n_seg = (nram_size - weight_size - scalar_size) / - (c_align_size * split_pipeline_num + - target_data_width * split_target_num); - auto seg_num = (N + n_seg - 1) / n_seg; - + const int threshold_c = + PAD_DOWN((nram_size - split_target_num * sizeof(int)) / + (split_pipeline_num + has_weight), + NFU_ALIGN_SIZE) / + input.itemsize(); + + int n_seg = 1; + if (C <= threshold_c) { + int c_size = C * input.itemsize(); + int reservered_align_size = + (split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE; + int wegiht_size = 0; + if (has_weight) { + c_size = c_align_size; + reservered_align_size = split_target_num * NFU_ALIGN_SIZE; + wegiht_size = c_align_size; + } + // n_seg * c_size * split_pipeline_num + n_seg * target.itemsize() * + // split_target_num + // + weight_size + reservered_align_size <= nram_size + n_seg = (nram_size - wegiht_size - reservered_align_size) / + (split_pipeline_num * c_size + split_target_num * sizeof(int32_t)); + } + auto seg_num = n_seg == 0 ? N : (N + n_seg - 1) / n_seg; auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); auto core_num = core_dim * cluster_num; @@ -103,31 +117,8 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, CNLOG(INFO) << "weight is a empty tensor."; } - // check C - auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - auto input_N = input.size(0); - auto input_C = input.size(1); - const int split_target_num = 2; - const int split_pipeline_num = 6; - const int has_weight = (int)(weight.data_ptr() != nullptr); - - // target supports only INT on MLU device while it keeps LONG on host side, - // so target.itemsize() / 2 - const int target_data_width = target.scalar_type() == at::kLong - ? target.itemsize() / 2 - : target.itemsize(); - auto threshold_C = PAD_DOWN((nram_size - NFU_ALIGN_SIZE - - split_target_num * target_data_width) / - (split_pipeline_num + has_weight), - NFU_ALIGN_SIZE) / - input.itemsize(); - - TORCH_CHECK(threshold_C >= input_C, - "input.size(1) should be in the range of [0, ", threshold_C, - "]. ", "But now input.size(1) is ", input_C, "."); - + // return if zero-element if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) { - // return if zero-element return; } @@ -158,8 +149,8 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, << k_dim.z << ">>>"; // launch kernel KernelFocalLossSigmoidForward(k_dim, k_type, queue, d_type, input_ptr, - target_ptr, weight_ptr, input_N, input_C, alpha, - gamma, output_ptr); + target_ptr, weight_ptr, input.size(0), + input.size(1), alpha, gamma, output_ptr); } void getDealNAndThresholdC(const int compute_data_bytes, diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp index af193fce33..e268199998 100644 --- a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp @@ -19,6 +19,16 @@ void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const float iou_threshold, const float offset, void *workspace_ptr, void *output_size_ptr, void *output_ptr); +int selectUnionType(uint32_t use_job, int box_num_per_core) { + // the box_num_per_core should be at least 256, otherwise the real IO + // bandwidth would be very low + while (box_num_per_core < 256 && use_job >= 4) { + box_num_per_core *= 2; + use_job /= 2; + } + return use_job; +} + Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, int offset) { // dimension parameters check @@ -42,32 +52,57 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, } int input_num_boxes = boxes.size(0); - int input_stride = boxes.size(1); + int input_stride = boxes.size(0); int max_output_boxes = boxes.size(0); - cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; - int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - uint32_t dim_x = core_dim; - cnrtDim3_t k_dim = {dim_x, 1, 1}; + cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); + cnrtDim3_t k_dim; + cnrtJobType_t k_type; + uint32_t union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + uint32_t job_limit = union_number * core_dim; + uint32_t core_number = union_number * core_dim; + int box_num_per_core = (input_num_boxes + core_number - 1) / core_number; + // initiate k_type as Union1 + k_dim.x = core_dim; + k_dim.y = 1; + k_dim.z = 1; + k_type = CNRT_FUNC_TYPE_UNION1; + int use_job = selectUnionType(job_limit, box_num_per_core); + if (use_job < 4) { + k_dim.x = 1; + k_type = CNRT_FUNC_TYPE_BLOCK; + } else if (use_job == 4) { + k_dim.x = core_dim; + k_type = CNRT_FUNC_TYPE_UNION1; + } else { + k_dim.x = use_job; + k_type = (cnrtFunctionType_t)use_job; + } + // transpose boxes (n, 4) to (4, n) for better performance + auto boxes_t = boxes.transpose(0, 1); + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t); + auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong)); auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); // workspace + const int info_num = 5; // x1, x2, y1, y2 and score size_t space_size = 0; if (boxes.scalar_type() == at::kHalf) { - space_size = input_num_boxes * sizeof(int16_t); + space_size = input_num_boxes * sizeof(int16_t) * info_num + sizeof(float); } else { - space_size = input_num_boxes * sizeof(float); + space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float); } auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte)); // get compute queue auto queue = torch_mlu::getCurQueue(); - auto boxes_impl = torch_mlu::getMluTensorImpl(boxes); + auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); auto boxes_ptr = boxes_impl->cnnlMalloc(); - auto scores_impl = torch_mlu::getMluTensorImpl(scores); + auto scores_impl = torch_mlu::getMluTensorImpl(scores_); auto scores_ptr = scores_impl->cnnlMalloc(); auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); auto workspace_ptr = workspace_impl->cnnlMalloc(); @@ -76,20 +111,11 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); auto output_size_ptr = output_size_impl->cnnlMalloc(); - switch (k_type) { - default: { - TORCH_CHECK(false, "[nms_mlu]:Failed to choose kernel to launch"); - } - case CNRT_FUNC_TYPE_BLOCK: - case CNRT_FUNC_TYPE_UNION1: { - CNLOG(INFO) << "Launch Kernel MLUUnion1 or Block NMS<<>>"; - KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, - input_num_boxes, input_stride, max_output_boxes, iou_threshold, - offset, workspace_ptr, output_size_ptr, output_ptr); - }; break; - } + CNLOG(INFO) << "Launch Kernel MLUUnionX NMS<<>>"; + KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, + input_num_boxes, input_stride, max_output_boxes, iou_threshold, + offset, workspace_ptr, output_size_ptr, output_ptr); int output_num = *static_cast(output_size.cpu().data_ptr()); return output.slice(0, 0, output_num); diff --git a/mmcv/ops/csrc/pytorch/roi_align.cpp b/mmcv/ops/csrc/pytorch/roi_align.cpp index 5c3a1dd30d..a9cdcdbcdf 100644 --- a/mmcv/ops/csrc/pytorch/roi_align.cpp +++ b/mmcv/ops/csrc/pytorch/roi_align.cpp @@ -122,11 +122,11 @@ void roi_align_forward(Tensor input, Tensor rois, Tensor output, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(rois); - CHECK_MLU(output); - CHECK_MLU(argmax_y); - CHECK_MLU(argmax_x); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(rois); + CHECK_MLU_INPUT(output); + CHECK_MLU_INPUT(argmax_y); + CHECK_MLU_INPUT(argmax_x); roi_align_forward_mlu(input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, spatial_scale, @@ -164,11 +164,11 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, #endif #ifdef MMCV_WITH_MLU } else if (grad_output.device().type() == at::kMLU) { - CHECK_MLU(grad_output); - CHECK_MLU(rois); - CHECK_MLU(argmax_y); - CHECK_MLU(argmax_x); - CHECK_MLU(grad_input); + CHECK_MLU_INPUT(grad_output); + CHECK_MLU_INPUT(rois); + CHECK_MLU_INPUT(argmax_y); + CHECK_MLU_INPUT(argmax_x); + CHECK_MLU_INPUT(grad_input); roi_align_backward_mlu(grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height, aligned_width, spatial_scale, From 5ff6c37767aeaf68fa21a2c797b42cfc2d14c01e Mon Sep 17 00:00:00 2001 From: zhouchenyang Date: Mon, 14 Mar 2022 23:49:26 +0800 Subject: [PATCH 21/30] [Improve] Improve the performance of roialign with MLU backend (#1741) --- .../csrc/common/mlu/roi_align_mlu_kernel.mlu | 1061 +++-------------- 1 file changed, 196 insertions(+), 865 deletions(-) diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu index 55df914ab0..f62554d0ef 100644 --- a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu @@ -11,907 +11,250 @@ *************************************************************************/ #include "common_mlu_helper.hpp" -__nram__ char buffer[MAX_NRAM_SIZE]; - -#define ALIGN_SIZE 64 -#define BUFFER_SIZE (MAX_NRAM_SIZE * 480 / 512) #define ROI_OFFSET 5 -#define SAMPLING_NUM 4 -#define DIM_BOX 5 -#define BLOCK_INPUT_OUTPUT 2 +__nram__ char buffer[MAX_NRAM_SIZE]; namespace forward { template -__mlu_func__ void bilinearInterpolate( - T *tmp_sum, T *nram_in, T *offset_bottom_data, const int roi_bin_grid_h, - const int roi_bin_grid_w, const T bin_size_h, const T bin_size_w, - const int input_height, const int input_width, const int channels, - const int channel_align, const int cyc_channel, T y_pre, T x_pre, - T zero_sign_tmp, bool is_normal_c, int index) { - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - T y = (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)) <= 0.0 - ? 0.0 - : (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)); - int y_low = int(y); - int y_high; - if (y_low >= input_height - 1) { - y_high = y_low = input_height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - T ly = y - y_low; - T hy = 1.0 - ly; - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - T x = (x_pre + ((ix + 0.5) * bin_size_w) / (T)(roi_bin_grid_w)) <= 0.0 - ? 0.0 - : (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)); - T zero_sign = - (T)(x >= -1.0 && x <= input_width && y >= -1.0 && y <= input_height) * - zero_sign_tmp; - int x_low = int(x); - int x_high; - if (x_low >= input_width - 1) { - x_high = x_low = input_width - 1; - x = T(x_low); - } else { - x_high = x_low + 1; - } - T lx = x - x_low; - T hx = 1.0 - lx; - - T w1 = hy * hx * zero_sign; - T w2 = hy * lx * zero_sign; - T w3 = ly * hx * zero_sign; - T w4 = ly * lx * zero_sign; - - // load - int cpy_len = (x_high - x_low) * channels; - int temp_size = cyc_channel < (channels - index * cyc_channel) - ? cyc_channel - : channels - index * cyc_channel; - int cpy_size = is_normal_c ? channels * sizeof(T) : temp_size * sizeof(T); - - int32_t offset1 = (y_low * input_width + x_low) * channels; - int32_t offset2 = (y_high * input_width + x_low) * channels; - - T *tmp1 = is_normal_c - ? offset_bottom_data + offset1 - : offset_bottom_data + offset1 + cyc_channel * index; - T *tmp2 = is_normal_c - ? offset_bottom_data + offset2 - : offset_bottom_data + offset2 + cyc_channel * index; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + cyc_channel; - T *tmp_cyc3 = nram_in + cyc_channel * 2; - T *tmp_cyc4 = nram_in + cyc_channel * 3; - - __asm__ volatile("sync;"); - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - __nramset(nram_in, channel_align, T(0)); - } else { - __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); - __asm__ volatile("sync;"); - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); - __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, 1, - SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); - } +__mlu_func__ void bilinearInterpolate(const int input_height, + const int input_width, T y, T x, T *w1, + T *w2, T *w3, T *w4, int *x_low, + int *x_high, int *y_low, int *y_high, + bool *empty) { + // deal with cases that inverse elements are of feature map boundary + if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { + *empty = true; + return; } -} -template -__mlu_func__ void roialignForwardNpartKernel( - T *input, T *rois, T *output, T *nram_buffer, const bool aligned, - const int channels, const int pooled_height, const int pooled_width, - const int input_height, const int input_width, const int sampling_ratio, - const float spatial_scale, const int num_rois, const int max_elements) { - /* - * NRAM partition - * |----------------------NRAM------------------------| - * | | - * | output | - * |--------------------------------------------------| - * | | - * | input | - * | | - * |--------------------------------------------------| - * | rois(batch_id, x1, y1, x2, y2) | - * |--------------------------------------------------| - * - * channel data will loop inside of input_nram, when channel * size(T) > - * input_nram - */ - - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; + if (y <= 0) y = 0; + if (x <= 0) x = 0; - // multi-core params - int inter_num = num_rois / taskDim; - int rem_num = num_rois % taskDim; - int offset_length; - int task_length; + int y_low_ = int(y); + int x_low_ = int(x); - // the length dealt by every core and the offset of taskId - if (taskId < rem_num) { - task_length = inter_num + 1; - offset_length = taskId * (inter_num + 1); + if (y_low_ >= input_height - 1) { + *y_high = y_low_ = input_height - 1; + y = (T)y_low_; } else { - task_length = inter_num; - offset_length = rem_num * (inter_num + 1) + (taskId - rem_num) * inter_num; - } - - int max_size = max_elements; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size * 2; - - int pooled_size = pooled_height * pooled_width; - T *top_data = output + offset_length * pooled_size * channels; - T *task_rois = rois + offset_length * ROI_OFFSET; - - for (int roi_id = 0; roi_id < task_length; roi_id++) { - // For each roi, find the corresponding feature map which it belongs to, - // and compute the scaling_factor to map it to that feature map. - T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = task_rois + roi_id * ROI_OFFSET; - - int batch_id = roi_id_tmp[0]; - T roi_xmin = roi_id_tmp[1]; - T roi_ymin = roi_id_tmp[2]; - T roi_xmax = roi_id_tmp[3]; - T roi_ymax = roi_id_tmp[4]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - roi_width = roi_width > 1.0 ? roi_width : 1.0; - roi_height = roi_height > 1.0 ? roi_height : 1.0; - } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); - - for (int ph = 0; ph < pooled_height; ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, channel_align, channel_align, - y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = - samp_channel / (max_elements * SAMPLING_NUM) + - (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); - int cyc_channel = max_elements; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph - } // loop for num_roi -} - -template -__mlu_func__ void roialignForwardHpartKernel( - T *input, T *rois, T *output, T *nram_buffer, const bool aligned, - const int channels, const int pooled_height, const int pooled_width, - const int input_height, const int input_width, const int sampling_ratio, - const float spatial_scale, const int num_rois, const int max_elements) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int taskdim_cyc = taskDim / num_rois > 1 ? taskDim / num_rois : 1; - int roi_id = taskId / taskdim_cyc; - if (taskId >= taskdim_cyc * num_rois) { - return; + *y_high = y_low_ + 1; } - // multi-core params - int inter_num = pooled_height / taskdim_cyc; - int rem_num = pooled_height % taskdim_cyc; - int offset_length; - int task_length; - - if ((taskId % taskdim_cyc) < rem_num) { - task_length = inter_num + 1; - offset_length = (taskId % taskdim_cyc) * (inter_num + 1); + if (x_low_ >= input_width - 1) { + *x_high = x_low_ = input_width - 1; + x = T(x_low_); } else { - task_length = inter_num; - offset_length = rem_num * (inter_num + 1) + - ((taskId % taskdim_cyc) - rem_num) * inter_num; - } - - int max_size = max_elements * 2; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - - int pooled_size = pooled_height * pooled_width; - T *top_data = - output + (roi_id * pooled_size + offset_length * pooled_width) * channels; - T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = rois + roi_id * ROI_OFFSET; - - int batch_id = roi_id_tmp[0]; - T roi_xmin = roi_id_tmp[1]; - T roi_ymin = roi_id_tmp[2]; - T roi_xmax = roi_id_tmp[3]; - T roi_ymax = roi_id_tmp[4]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; + *x_high = x_low_ + 1; } - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); + *y_low = y_low_; + *x_low = x_low_; - for (int ph = offset_length; ph < (offset_length + task_length); ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, - bin_size_w, input_height, input_width, channels, - channel_align, channel_align, y_pre, x_pre, - zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / (max_elements * SAMPLING_NUM) + - (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); - int cyc_channel = max_elements; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph -} - -__mlu_global__ void MLUUnion1KernelRoialign( - const void *input, const void *rois, const int channels, const bool aligned, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const cnrtDataType_t data_type, void *output) { - size_t data_type_size = - (data_type == CNRT_FLOAT32) ? sizeof(float) : sizeof(half); - int max_elements = PAD_DOWN( - (BUFFER_SIZE / (int)data_type_size) / (ROI_OFFSET + 1), ALIGN_SIZE); - - if (taskDim < num_rois || (num_rois * pooled_height < taskDim)) { - switch (data_type) { - case CNRT_FLOAT16: { - half *nram_buffer = (half *)buffer; - roialignForwardNpartKernel( - (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_elements); - }; break; - case CNRT_FLOAT32: { - float *nram_buffer = (float *)buffer; - roialignForwardNpartKernel( - (float *)input, (float *)rois, (float *)output, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); - }; break; - default: - break; - } - } else { - switch (data_type) { - case CNRT_FLOAT16: { - half *nram_buffer = (half *)buffer; - roialignForwardHpartKernel( - (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_elements); - }; break; - case CNRT_FLOAT32: { - float *nram_buffer = (float *)buffer; - roialignForwardHpartKernel( - (float *)input, (float *)rois, (float *)output, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); - }; break; - default: - break; - } - } + T ly = y - y_low_; + T lx = x - x_low_; + T hy = 1.0 - ly; + T hx = 1.0 - lx; + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; return; } template -__mlu_func__ void buSelection(T *rois_count, T *nram_temp, const int num_rois) { - for (int i = 0; i < num_rois; ++i) { - for (int j = 1; j < num_rois; ++j) { - if (rois_count[(j - 1) * 2] < rois_count[j * 2]) { - nram_temp[0] = rois_count[(j - 1) * 2]; - rois_count[(j - 1) * 2] = rois_count[j * 2]; - rois_count[j * 2] = nram_temp[0]; - nram_temp[1] = rois_count[(j - 1) * 2 + 1]; - rois_count[(j - 1) * 2 + 1] = rois_count[j * 2 + 1]; - rois_count[j * 2 + 1] = nram_temp[1]; - } - } - } -} - -template -__mlu_func__ void getPatitionList(T *h_nram, T *n_nram, T *roi_count, - int pooled_height, int num_rois, T sum, - int split_num, int &h_flag, int &n_flag) { - T avg_sum = sum / split_num; - T *h_nram_temp = h_nram; - T *n_nram_temp = n_nram; - - int n_index = 0; - T n_sum = 0; - h_flag = 0; - n_flag = 0; - int list_align = PAD_UP(ALIGN_SIZE * 5, ALIGN_SIZE); - __bang_write_zero(h_nram, list_align); - for (int i = 0; i < num_rois; i++) { - if (roi_count[2 * i] >= avg_sum) { - int h_num = std::ceil(roi_count[2 * i] / avg_sum); - int h_split = pooled_height / h_num; - int h_rem = pooled_height % h_num; - T h_sum = 0.0; - - for (int j = 0; j < h_num; j++) { - h_nram_temp[0] = i; - h_nram_temp[1] = h_sum; - h_nram_temp[2] = (j < h_rem) ? (h_split + 1) : h_split; - h_sum += h_nram_temp[2]; - h_nram_temp += 3; - n_nram_temp += 2; - h_flag++; - } - } else { - if (roi_count[2 * i] + n_sum > avg_sum) { - n_nram_temp[0] = i - n_index; - n_nram_temp[1] = i - 1; - n_sum = 0.0; - n_index = 0; - n_nram_temp += 2; - i--; - n_flag++; - } else { - n_index++; - n_sum += roi_count[2 * i]; - } - } - } - if (n_flag == 0 && n_index != 0) { - n_flag = 1; - n_nram[(h_flag + n_flag - 1) * 2] = num_rois - 1; - } - - n_nram[(h_flag + n_flag) * 2 - 1] = num_rois - 1; - - if (h_flag + n_flag > taskDim) { - getPatitionList(h_nram, n_nram, roi_count, pooled_height, num_rois, sum, - split_num - 1, h_flag, n_flag); - } - return; +__mlu_func__ void computeChannel(T *input_core, T *nram_in, T *output_core, + T *nram_out, const int roi_bin_grid_h, + const int roi_bin_grid_w, const T roi_start_h, + const T roi_start_w, const int ph, + const int pw, const T bin_size_h, + const T bin_size_w, const float count, + const int input_height, const int input_width, + const int channels, const int cyc_num, + const int max_elements) { + int cyc_channel = max_elements; + + for (int i = 0; i < cyc_num; i++) { + int real_channel = + (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; + int align_channel = PAD_UP(real_channel, NFU_ALIGN_SIZE / sizeof(T)); + __bang_write_zero(nram_out, align_channel); + uint32_t real_size = real_channel * sizeof(T); + + int iy, ix; + for (iy = 0; iy < roi_bin_grid_h; iy++) { + // 1. compute the coordinates of the y axis in the current roi_bin_grid_h + T y = roi_start_h + ph * bin_size_h + + (T)(iy + 0.5) * bin_size_h / (T)(roi_bin_grid_h); + for (ix = 0; ix < roi_bin_grid_w; ix++) { + // 2. compute the coordinates of the x axis in the current + // roi_bin_grid_w + T x = roi_start_w + pw * bin_size_w + + (T)(ix + 0.5) * bin_size_w / (T)(roi_bin_grid_w); + + // 3. compute the four weights (w1, w2, w3 and w4), the height (y_low + // and y_high) and weight (x_low and x_high) of input feature map in + // the current roi bin grid, and the flag (empty) which shows if x, y + // are out of input feature map ranges + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bool empty = false; + + bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, &w3, &w4, + &x_low, &x_high, &y_low, &y_high, &empty); + + // 4. compute interpolation of the current roi bin grid + // tmp_cyc1, temp_cyc2, tmp_cyc3 and tmp_cyc4 store the input values + // to compute the interpolation, and then reused to compute + // the argmax_x and argmax_y. + T *tmp_cyc1 = nram_in + cyc_channel; + T *tmp_cyc2 = nram_in + cyc_channel * 2; + T *tmp_cyc3 = nram_in + cyc_channel * 3; + T *tmp_cyc4 = nram_in + cyc_channel * 4; + + if (empty) { // exits abnormal values + __bang_write_zero(nram_in, align_channel); + } else { + __bang_write_zero(nram_in, align_channel); + uint32_t offset1 = (y_low * input_width + x_low) * channels; + uint32_t offset2 = (y_low * input_width + x_high) * channels; + uint32_t offset3 = (y_high * input_width + x_low) * channels; + uint32_t offset4 = (y_high * input_width + x_high) * channels; + T *input1 = (T *)input_core + offset1 + i * cyc_channel; + T *input2 = (T *)input_core + offset2 + i * cyc_channel; + T *input3 = (T *)input_core + offset3 + i * cyc_channel; + T *input4 = (T *)input_core + offset4 + i * cyc_channel; + + // load the four pixels (p1, p2, p3 and p4) of input feature map to + // compute interpolation + __memcpy(tmp_cyc1, input1, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc2, input2, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc3, input3, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc4, input4, real_size, GDRAM2NRAM); + + // interpolation value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4 + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, align_channel); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, align_channel); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, align_channel); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, align_channel); + + __bang_add(nram_in, tmp_cyc1, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc2, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc3, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc4, nram_in, align_channel); + } + // 5. compute sum value and corresponding coordinates of x axis and y + // axis. Update the sum value. + __bang_add(nram_out, nram_in, nram_out, align_channel); + } // loop_roi_grid_w + } // loop_roi_grid_h + T count_value = (T)(1.0 / count); + __bang_mul_const(nram_out, nram_out, count_value, align_channel); + __memcpy(output_core + i * cyc_channel, nram_out, real_size, NRAM2GDRAM); + } // loop_cyc_num } template -__mlu_func__ void mergeAndSplitQuantity( - T *rois, T *rois_sort, T *split_list, T *roi_count, T *nram_rois, - const bool aligned, const int pooled_height, const int pooled_width, - const int sampling_ratio, const float spatial_scale, const int num_rois, - int &h_split_num, int &n_split_num) { - /* take the coordinates out of ROIS and actually calculate the actual - * calculation size. The sorted calculation scale is partition, large scale - * is split H, small is N. - */ - T *h_tem = split_list; - T *n_tem = split_list + 3 * ALIGN_SIZE; - int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 1), ALIGN_SIZE); - int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); - __bang_write_zero(nram_rois, num_rois_align); - T sum = 0.0; - int temp_offset = 0; - __memcpy((void *)(nram_rois + 1), (void *)rois, ROI_OFFSET * sizeof(T), - GDRAM2NRAM, (ROI_OFFSET + 1) * sizeof(T), ROI_OFFSET * sizeof(T), - (num_rois - 1)); - T *nram_temp = roi_count + count_align; - for (int roi_id = 0; roi_id < num_rois; roi_id++) { - T offset = aligned ? (T)0.5 : (T)0; - - T roi_xmin = nram_rois[temp_offset + 2]; - T roi_ymin = nram_rois[temp_offset + 3]; - T roi_xmax = nram_rois[temp_offset + 4]; - T roi_ymax = nram_rois[temp_offset + 5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; +__mlu_func__ void roialignForwardAvg( + T *input, T *rois, T *output, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const T spatial_scale, + const int num_rois) { + // find limit for channel, the nram space is divided to 6 parts that are + // input, 4 weights to compute the interpolation (w1, w2, w3, w4), output + + // max_elements : 300 : float datatype : 27296, half datatype : 54592 + // max_elements : 200 : float datatype : 16384, half datatype : 32768 + int max_elements = (PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE)) / sizeof(T); + int cyc_num = channels / max_elements + (int)(channels % max_elements != 0); + T offset = aligned ? (T)0.5 : (T)0.0; + int task_num = num_rois * pooled_height * pooled_width; + T *nram_out = (T *)buffer; + T *nram_in = nram_out + max_elements; + if (task_num < taskDim) { + if (taskId >= task_num) { + return; } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - sum += count; - *(roi_count + 2 * roi_id) = count; - *(roi_count + 2 * roi_id + 1) = roi_id; - - *(nram_rois + roi_id * (ROI_OFFSET + 1)) = count; - temp_offset += (ROI_OFFSET + 1); } - buSelection(roi_count, nram_temp, num_rois); - - temp_offset = 0; - for (int i = 0; i < num_rois; i++) { - for (int j = 0; j < num_rois; j++) { - if (roi_count[2 * i] == nram_rois[j * (ROI_OFFSET + 1)]) { - rois_sort[temp_offset] = nram_rois[j * (ROI_OFFSET + 1)]; - rois_sort[temp_offset + 1] = nram_rois[j * (ROI_OFFSET + 1) + 1]; - rois_sort[temp_offset + 2] = nram_rois[j * (ROI_OFFSET + 1) + 2]; - rois_sort[temp_offset + 3] = nram_rois[j * (ROI_OFFSET + 1) + 3]; - rois_sort[temp_offset + 4] = nram_rois[j * (ROI_OFFSET + 1) + 4]; - rois_sort[temp_offset + 5] = nram_rois[j * (ROI_OFFSET + 1) + 5]; - nram_rois[j * (ROI_OFFSET + 1)] = -1.0; - break; - } + for (int bin_idx = taskId; bin_idx < task_num; bin_idx = bin_idx + taskDim) { + if (bin_idx >= task_num) { + return; } - temp_offset += (ROI_OFFSET + 1); - } - getPatitionList(h_tem, n_tem, roi_count, pooled_height, num_rois, sum, - taskDim, h_split_num, n_split_num); -} - -template -__mlu_func__ void roialignForwardNpartKernelForBinPart( - T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, - T *nram_buffer, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_size) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int max_elements = max_size * SAMPLING_NUM; - int offset_length; - int task_length; - - T *n_split_nram = split_list + 3 * ALIGN_SIZE + 2 * taskId; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - T *task_rois = rois_sort + (int)n_split_nram[0] * (ROI_OFFSET + 1); - - offset_length = (int)n_split_nram[0]; - task_length = n_split_nram[1] - n_split_nram[0] + 1; - int pooled_size = pooled_height * pooled_width; - - for (int roi_id = offset_length; roi_id < offset_length + task_length; - roi_id++) { - // For each roi, find the corresponding feature map which it belongs to, - // and compute the scaling_factor to map it to that feature map. - T offset = aligned ? (T)0.5 : (T)0; - int rea_out_id = rois_count[roi_id * 2 + 1]; - T *top_data = output + rea_out_id * pooled_size * channels; - T *nram_rois = task_rois + (roi_id - offset_length) * (ROI_OFFSET + 1); - int batch_id = nram_rois[1]; - T roi_xmin = nram_rois[2]; - T roi_ymin = nram_rois[3]; - T roi_xmax = nram_rois[4]; - T roi_ymax = nram_rois[5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; + // (n,ph.pw) is a c in the pooled output + int pw = bin_idx % pooled_width; + int ph = (bin_idx / pooled_width) % pooled_height; + int n = bin_idx / pooled_width / pooled_height; + + T *roi_id_tmp = rois + n * ROI_OFFSET; + // 1. compute width and height of roi region. + int batch_idx = (int)roi_id_tmp[0]; + T roi_x1 = roi_id_tmp[1]; + T roi_y1 = roi_id_tmp[2]; + T roi_x2 = roi_id_tmp[3]; + T roi_y2 = roi_id_tmp[4]; + T roi_start_w = roi_x1 * spatial_scale - offset; + T roi_start_h = roi_y1 * spatial_scale - offset; + T roi_end_w = roi_x2 * spatial_scale - offset; + T roi_end_h = roi_y2 * spatial_scale - offset; + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1.0 ? roi_width : 1.0; - roi_height = roi_height > 1.0 ? roi_height : 1.0; + roi_width = roi_width > (T)(1.0) ? roi_width : (T)(1.0); + roi_height = roi_height > (T)(1.0) ? roi_height : (T)(1.0); } - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_in, max_elements); - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < max_elements; - - for (int ph = 0; ph < pooled_height; ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, channel_align, channel_align, - y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / max_elements + - (int)(samp_channel % max_elements != 0); - int cyc_channel = max_elements / SAMPLING_NUM; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph - } // loop for num_roi -} - -template -__mlu_func__ void roialignForwardHpartKernelForBinPart( - T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, - T *nram_buffer, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_size) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int max_elements = max_size * SAMPLING_NUM; - - T *h_split_nram = split_list; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - T *nram_rois = rois_sort + (int)h_split_nram[taskId * 3] * (ROI_OFFSET + 1); - - int offset_length = (int)h_split_nram[taskId * 3 + 1]; - int task_length = (int)h_split_nram[taskId * 3 + 2]; - int rea_out_id = (int)h_split_nram[taskId * 3]; - - rea_out_id = rois_count[rea_out_id * 2 + 1]; - int pooled_size = pooled_height * pooled_width; - T *top_data = - output + - (rea_out_id * pooled_size + offset_length * pooled_width) * channels; - - T offset = aligned ? (T)0.5 : (T)0; - - int batch_id = nram_rois[1]; - T roi_xmin = nram_rois[2]; - T roi_ymin = nram_rois[3]; - T roi_xmax = nram_rois[4]; - T roi_ymax = nram_rois[5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; + // 2. compute float-type width and height of roi bin region. + T bin_size_w = (T)roi_width / (T)pooled_width; + T bin_size_h = (T)roi_height / (T)pooled_height; + + // 3. compute int-type width and height of roi bin region. + int roi_bin_grid_h, roi_bin_grid_w; + roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : int(ceilf(roi_height / pooled_height)); + roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : int(ceilf(roi_width / pooled_width)); + float count = (float)((roi_bin_grid_h * roi_bin_grid_w) > 1 + ? roi_bin_grid_h * roi_bin_grid_w + : 1.0); + T *input_core = input + batch_idx * channels * input_width * input_height; + T *output_core = output + bin_idx * channels; + // 4. compute avg value and corresponding coordinates of x axis and y axis. + computeChannel(input_core, nram_in, output_core, nram_out, roi_bin_grid_h, + roi_bin_grid_w, roi_start_h, roi_start_w, ph, pw, bin_size_h, + bin_size_w, count, input_height, input_width, channels, + cyc_num, max_elements); } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_in, max_elements); - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < max_elements; - - for (int ph = offset_length; ph < (offset_length + task_length); ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, - bin_size_w, input_height, input_width, channels, - channel_align, channel_align, y_pre, x_pre, - zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / max_elements + - (int)(samp_channel % max_elements != 0); - int cyc_channel = max_elements / SAMPLING_NUM; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph } -__mlu_global__ void MLUUnion1KernelBinPartRoialign( +__mlu_global__ void MLUUnion1KernelRoiAlignAvg( const void *input, const void *rois, const int channels, const bool aligned, const int pooled_height, const int pooled_width, const int input_height, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, const cnrtDataType_t data_type, void *output) { - int h_split_num = 0; - int n_split_num = 0; - int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 4), ALIGN_SIZE); - int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); - int list_align = ALIGN_SIZE * 5; - int sum_size = num_rois_align + count_align + list_align; - + // make sure that memcore is not used if (coreId == 0x80) { return; } switch (data_type) { case CNRT_FLOAT16: { - int max_channel = - PAD_DOWN((BUFFER_SIZE / sizeof(half) - sum_size) / (ROI_OFFSET + 1), - ALIGN_SIZE); - half *rois_sort = (half *)buffer; - __bang_write_zero(rois_sort, sum_size); - half *rois_count = (half *)(rois_sort + num_rois_align); - half *split_list = (half *)(rois_count + count_align); - half *nram_rois = (half *)(split_list + list_align); - mergeAndSplitQuantity((half *)rois, (half *)rois_sort, (half *)split_list, - (half *)rois_count, (half *)nram_rois, aligned, - pooled_height, pooled_width, sampling_ratio, - spatial_scale, num_rois, h_split_num, n_split_num); - half *nram_buffer = (half *)nram_rois; - __bang_write_zero(nram_rois, num_rois_align); - - if (taskId < h_split_num) { - roialignForwardHpartKernelForBinPart( - (half *)input, (half *)rois, (half *)output, (half *)rois_sort, - (half *)split_list, (half *)rois_count, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_channel); - } else { - if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { - roialignForwardNpartKernelForBinPart( - (half *)input, (half *)rois, (half *)output, (half *)rois_sort, - (half *)split_list, (half *)rois_count, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, - max_channel); - } else { - return; - } - } + roialignForwardAvg((half *)input, (half *)rois, (half *)output, aligned, + channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, + (half)spatial_scale, num_rois); }; break; case CNRT_FLOAT32: { - int max_channel = - PAD_DOWN((BUFFER_SIZE / sizeof(float) - sum_size) / (ROI_OFFSET + 1), - ALIGN_SIZE); - float *rois_sort = (float *)buffer; - __bang_write_zero(rois_sort, sum_size); - float *rois_count = (float *)(rois_sort + num_rois_align); - float *split_list = (float *)(rois_count + count_align); - float *nram_rois = (float *)(split_list + list_align); - mergeAndSplitQuantity((float *)rois, (float *)rois_sort, - (float *)split_list, (float *)rois_count, - (float *)nram_rois, aligned, pooled_height, - pooled_width, sampling_ratio, spatial_scale, - num_rois, h_split_num, n_split_num); - float *nram_buffer = (float *)nram_rois; - __bang_write_zero(nram_rois, num_rois_align); - - if (taskId < h_split_num) { - roialignForwardHpartKernelForBinPart( - (float *)input, (float *)rois, (float *)output, (float *)rois_sort, - (float *)split_list, (float *)rois_count, (float *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_channel); - } else { - if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { - roialignForwardNpartKernelForBinPart( - (float *)input, (float *)rois, (float *)output, - (float *)rois_sort, (float *)split_list, (float *)rois_count, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_channel); - } else { - return; - } - } + roialignForwardAvg((float *)input, (float *)rois, (float *)output, + aligned, channels, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, + (float)spatial_scale, num_rois); }; break; default: break; } + return; } } // namespace forward @@ -1131,21 +474,9 @@ void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, void *output) { - // set thresholds for degradation caused by sorting - const int sort_border = 100; // threshold of num_rois - const int sort_cluster_num = 16; // threshold of cluster - - if (num_rois > sort_border || k_dim.y < sort_cluster_num) { - forward::MLUUnion1KernelRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, spatial_scale, num_rois, - d_type, output); - } else { - forward::MLUUnion1KernelBinPartRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, spatial_scale, num_rois, - d_type, output); - } + forward::MLUUnion1KernelRoiAlignAvg<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); } void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, From 8fe8d3fa3802170bf182adfe947e295c7dc8b70f Mon Sep 17 00:00:00 2001 From: Mrxiaofei <36697723+Mrxiaofei@users.noreply.github.com> Date: Tue, 29 Mar 2022 11:29:46 +0800 Subject: [PATCH 22/30] [Feature] Support tin_shift with cambricon MLU backend (#1696) * [Feature] Support tin_shift with cambricon MLU backend * [fix] Add the assertion of batch_size in tin_shift.py * [fix] fix the param check of tin_shift in cambricon code * [fix] Fix lint failure. * [fix] Fix source file lint failure. * Update mmcv/ops/tin_shift.py [Refactor] Modify the code in mmcv/ops/tin_shift.py. Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: budefei Co-authored-by: budefei Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- .../csrc/common/mlu/tin_shift_mlu_kernel.mlu | 307 ++++++++++++++++++ mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp | 185 +++++++++++ mmcv/ops/csrc/pytorch/tin_shift.cpp | 34 ++ mmcv/ops/tin_shift.py | 4 + tests/test_ops/test_tin_shift.py | 55 +++- 5 files changed, 568 insertions(+), 17 deletions(-) create mode 100644 mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp mode change 100644 => 100755 mmcv/ops/tin_shift.py mode change 100644 => 100755 tests/test_ops/test_tin_shift.py diff --git a/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu new file mode 100644 index 0000000000..7cb6df0e5d --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu @@ -0,0 +1,307 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" + +__nram__ char data_nram[MAX_NRAM_SIZE]; + +template +__mlu_func__ void mluMultiKernelTinShift( + const T *input, const int *shifts, T *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel) { + for (int cur_channel_index = taskId; + cur_channel_index < batch_size * channel_size; + cur_channel_index += taskDim) { + int n_index = cur_channel_index / channel_size; + int group_id = cur_channel_index % channel_size / group_channel; + int t_shift = shifts[n_index * group_size + group_id]; + int index = cur_channel_index % channel_size * hw_size + + n_index * time_size * channel_size * hw_size; + __nramset(data_nram, MAX_NRAM_SIZE, (char)0); + __asm__ volatile("sync;"); + if (abs(t_shift) >= time_size) { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + time_size - 1); + } else { + if (t_shift > 0) { + __memcpy(data_nram + t_shift * hw_size * sizeof(T), input + index, + hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T), + channel_size * hw_size * sizeof(T), time_size - 1 - t_shift); + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + time_size - 1); + } else { + __memcpy(data_nram, input + (index - t_shift * channel_size * hw_size), + hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T), + channel_size * hw_size * sizeof(T), time_size - 1 + t_shift); + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + time_size - 1); + } + } + __asm__ volatile("sync;"); + } +} + +template +__mlu_func__ void mluHwSplit(const T *input, const int t_shift, + const int time_size, const int hw_size, + const int channel_size, const int index, + const int cur_sequence_index, + const int max_length_per_core, T *output) { + for (int cur_index = index; cur_index < index + hw_size; + cur_index += max_length_per_core) { + int memcpy_size = max_length_per_core; + if (cur_index + max_length_per_core > index + hw_size) { + memcpy_size = index + hw_size - cur_index; + } + if (cur_sequence_index - t_shift < 0 || + cur_sequence_index - t_shift >= time_size) { + __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T), + NRAM2GDRAM); + } else { + __memcpy(data_nram, input + cur_index - t_shift * channel_size * hw_size, + memcpy_size * sizeof(T), GDRAM2NRAM); + __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T), + NRAM2GDRAM); + } + __asm__ volatile("sync;"); + } +} + +template +__mlu_func__ void mluMultiKernelTinShiftSplitSequence( + const T *input, const int *shifts, T *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel, + const int max_number_hw_per_core, const int max_length_per_core) { + const int tmp_max_number_hw_per_core = + max_number_hw_per_core > 0 ? max_number_hw_per_core : 1; + const int loop_time = time_size / tmp_max_number_hw_per_core + + ((time_size % tmp_max_number_hw_per_core) > 0 ? 1 : 0); + int segmentime_size = tmp_max_number_hw_per_core; + int res_segment = time_size % tmp_max_number_hw_per_core; + + for (int cur_segment_index = taskId; + cur_segment_index < loop_time * batch_size * channel_size; + cur_segment_index += taskDim) { + int n_index = cur_segment_index / loop_time / channel_size; + int group_id = cur_segment_index / loop_time % channel_size / group_channel; + int t_shift = shifts[n_index * group_size + group_id]; + int index = n_index * time_size * channel_size * hw_size + + (cur_segment_index / loop_time % channel_size) * hw_size + + cur_segment_index % loop_time * segmentime_size * hw_size * + channel_size; + char *dst_gdram2nram = data_nram; + const T *src_gdram2nram = input + index; + int count_gdram2nram = -1; + int count_nram2gdram = -1; + int next_sequence_index = + index / hw_size / channel_size % time_size + segmentime_size; + int cur_sequence_index = index / hw_size / channel_size % time_size; + __nramset(data_nram, MAX_NRAM_SIZE, (char)0); + __asm__ volatile("sync;"); + if (max_number_hw_per_core == 0) { + mluHwSplit(input, t_shift, time_size, hw_size, channel_size, index, + cur_sequence_index, max_length_per_core, output); + continue; + } + if (abs(t_shift) >= time_size) { + if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + res_segment - 1); + } else { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + segmentime_size - 1); + } + continue; + } + if (t_shift == 0) { + if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { + dst_gdram2nram = data_nram; + src_gdram2nram = input + index; + count_gdram2nram = res_segment - 1; + count_nram2gdram = res_segment - 1; + } else { + dst_gdram2nram = data_nram; + src_gdram2nram = input + index; + count_gdram2nram = segmentime_size - 1; + count_nram2gdram = segmentime_size - 1; + } + } else if (t_shift > 0) { + int first_index_cur_channel = + n_index * time_size * channel_size * hw_size + + (cur_segment_index / loop_time % channel_size) * hw_size; + if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { + dst_gdram2nram = data_nram; + src_gdram2nram = + input + + (index - t_shift * channel_size * hw_size < first_index_cur_channel + ? first_index_cur_channel + : index - t_shift * channel_size * hw_size); + count_gdram2nram = res_segment - 1; + count_nram2gdram = res_segment - 1; + if (cur_sequence_index < t_shift && t_shift < next_sequence_index) { + dst_gdram2nram = + data_nram + t_shift % segmentime_size * hw_size * sizeof(T); + count_gdram2nram = res_segment - (t_shift - cur_sequence_index) - 1; + } + } else { + if (t_shift >= next_sequence_index) { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + segmentime_size - 1); + continue; + } else if (cur_sequence_index < t_shift && + t_shift < next_sequence_index) { + dst_gdram2nram = + data_nram + t_shift % segmentime_size * hw_size * sizeof(T); + src_gdram2nram = input + first_index_cur_channel; + count_gdram2nram = segmentime_size - (t_shift % segmentime_size) - 1; + count_nram2gdram = segmentime_size - 1; + } else { + dst_gdram2nram = data_nram; + src_gdram2nram = input + index - t_shift * channel_size * hw_size; + count_gdram2nram = segmentime_size - 1; + count_nram2gdram = segmentime_size - 1; + } + } + } else { + int offset_index = time_size + t_shift; + if (cur_sequence_index >= offset_index) { + if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + res_segment - 1); + continue; + } else { + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + segmentime_size - 1); + continue; + } + } else { + dst_gdram2nram = data_nram; + src_gdram2nram = input + index - t_shift * channel_size * hw_size; + if (cur_sequence_index - t_shift + segmentime_size < time_size) { + count_gdram2nram = segmentime_size - 1; + count_nram2gdram = segmentime_size - 1; + } else { + count_gdram2nram = time_size - (cur_sequence_index - t_shift) - 1; + count_nram2gdram = + (segmentime_size - 1) < (time_size - cur_sequence_index - 1) + ? (segmentime_size - 1) + : (time_size - cur_sequence_index - 1); + } + } + } + __memcpy(dst_gdram2nram, src_gdram2nram, hw_size * sizeof(T), GDRAM2NRAM, + hw_size * sizeof(T), channel_size * hw_size * sizeof(T), + count_gdram2nram); + __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, + channel_size * hw_size * sizeof(T), hw_size * sizeof(T), + count_nram2gdram); + __asm__ volatile("sync;"); + } +} + +__mlu_entry__ void MLUUnion1KernelTinShift( + const void *input, const void *shifts, void *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel, + const cnrtDataType_t data_dtype) { + // make sure that memcore is not used + if (coreId == 0x80) { + return; + } + switch (data_dtype) { + case CNRT_FLOAT16: { + mluMultiKernelTinShift((half *)input, (const int *)shifts, (half *)output, + batch_size, time_size, channel_size, hw_size, + group_size, group_channel); + }; break; + case CNRT_FLOAT32: { + mluMultiKernelTinShift((float *)input, (const int *)shifts, + (float *)output, batch_size, time_size, + channel_size, hw_size, group_size, group_channel); + }; break; + default: { return; } + } +} + +__mlu_entry__ void MLUUnion1KernelTinShiftSplitSequence( + const void *input, const void *shifts, void *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel, + const int max_number_hw_per_core, const int max_length_per_core, + const cnrtDataType_t data_dtype) { + // make sure that memcore is not used + if (coreId == 0x80) { + return; + } + switch (data_dtype) { + case CNRT_FLOAT16: { + mluMultiKernelTinShiftSplitSequence( + (half *)input, (const int *)shifts, (half *)output, batch_size, + time_size, channel_size, hw_size, group_size, group_channel, + max_number_hw_per_core, max_length_per_core); + }; break; + case CNRT_FLOAT32: { + mluMultiKernelTinShiftSplitSequence( + (float *)input, (const int *)shifts, (float *)output, batch_size, + time_size, channel_size, hw_size, group_size, group_channel, + max_number_hw_per_core, max_length_per_core); + }; break; + default: { return; } + } +} + +void KernelTinShiftForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const void *input, const void *shifts, void *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel, + const cnrtDataType_t data_dtype, const int channel_per_core, + const int max_number_hw_per_core, const int max_length_per_core) { + if (channel_per_core >= 1) { + MLUUnion1KernelTinShift<<>>( + input, shifts, output, batch_size, time_size, channel_size, hw_size, + group_size, group_channel, data_dtype); + } else { + MLUUnion1KernelTinShiftSplitSequence<<>>( + input, shifts, output, batch_size, time_size, channel_size, hw_size, + group_size, group_channel, max_number_hw_per_core, max_length_per_core, + data_dtype); + } +} + +void KernelTinShiftBackward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const void *grad_output, const void *shifts, void *grad_input, + const int batch_size, const int time_size, const int channel_size, + const int hw_size, const int group_size, const int group_channel, + const cnrtDataType_t data_dtype, const int channel_per_core, + const int max_number_hw_per_core, const int max_length_per_core) { + if (channel_per_core >= 1) { + MLUUnion1KernelTinShift<<>>( + grad_output, shifts, grad_input, batch_size, time_size, channel_size, + hw_size, group_size, group_channel, data_dtype); + } else { + MLUUnion1KernelTinShiftSplitSequence<<>>( + grad_output, shifts, grad_input, batch_size, time_size, channel_size, + hw_size, group_size, group_channel, max_number_hw_per_core, + max_length_per_core, data_dtype); + } +} diff --git a/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp new file mode 100644 index 0000000000..5cc79df5bc --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp @@ -0,0 +1,185 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_mlu_helper.hpp" + +void KernelTinShiftForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const void *input, const void *shifts, void *output, const int batch_size, + const int time_size, const int channel_size, const int hw_size, + const int group_size, const int group_channel, + const cnrtDataType_t data_dtype, const int channel_per_core, + const int max_number_hw_per_core, const int max_length_per_core); + +void KernelTinShiftBackward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const void *grad_output, const void *shifts, void *grad_input, + const int batch_size, const int time_size, const int channel_size, + const int hw_size, const int group_size, const int group_channel, + const cnrtDataType_t data_dtype, const int channel_per_core, + const int max_number_hw_per_core, const int max_length_per_core); + +// policy function +static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type, int *channel_per_core, + int *max_number_hw_per_core, int *max_length_per_core) { + const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + const int core_num = core_limit * cluster_limit; + const int batch_size = input.size(0); + const int time_size = input.size(1); + const int channel_size = input.size(2); + const int hw_size = input.size(3); + + const size_t size_per_channel = time_size * hw_size * input.itemsize(); + *channel_per_core = nram_size / size_per_channel; + int task_dim = 0; + if (*channel_per_core == 0) { + const size_t size_per_hw = hw_size * input.itemsize(); + *max_number_hw_per_core = nram_size / size_per_hw; + if (*max_number_hw_per_core <= 0) { + *max_length_per_core = nram_size / input.itemsize(); + } + int tmp_max_number_hw_per_core = + *max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1; + const int loop_time = + (time_size / (tmp_max_number_hw_per_core)) + + ((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0); + task_dim = batch_size * channel_size * loop_time < core_num + ? batch_size * channel_size * loop_time + : core_num; + } else { + task_dim = batch_size * channel_size < core_num ? batch_size * channel_size + : core_num; + } + + k_dim->x = core_limit; + k_dim->y = (task_dim / core_limit) > 0 ? (task_dim / core_limit) : 1; + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_UNION1; +} + +void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift, + Tensor output) { + // params check + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "input type should be Float or Half, got ", input.scalar_type(), "."); + TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ", + input.dim(), "d."); + TORCH_CHECK(shift.dim() == 2, "shift should be a 2d tensor, got ", + shift.dim(), "d."); + TORCH_CHECK( + input.size(0) == shift.size(0), + "input batch size should be the same as shift's, input batch size is ", + input.size(0), " and shift batch size is ", shift.size(0), "."); + TORCH_CHECK(input.size(0) != 0, "Input batch size should not be zero."); + TORCH_CHECK(input.size(3) != 0, + "The last dim size of input should not be zero."); + if (input.size(1) == 0) { + return; + } + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + int channel_per_core = 0; + int max_number_hw_per_core = 0; + int max_length_per_core = 0; + policyFunc(input, &k_dim, &k_type, &channel_per_core, &max_number_hw_per_core, + &max_length_per_core); + + const int batch_size = input.size(0); + const int time_size = input.size(1); + const int channel_size = input.size(2); + const int hw_size = input.size(3); + const int group_size = shift.size(1); + int group_channel = channel_size / group_size; + + // get tensor impl + auto input_impl = torch_mlu::getMluTensorImpl(input); + auto shift_impl = torch_mlu::getMluTensorImpl(shift); + auto output_impl = torch_mlu::getMluTensorImpl(output); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get the mlu ptr + auto input_ptr = input_impl->cnnlMalloc(); + auto shift_ptr = shift_impl->cnnlMalloc(); + auto output_ptr = output_impl->cnnlMalloc(); + + cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype()); + + KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr, + batch_size, time_size, channel_size, hw_size, + group_size, group_channel, data_dtype, channel_per_core, + max_number_hw_per_core, max_length_per_core); +} + +void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift, + Tensor grad_input) { + // params check + TORCH_CHECK(grad_output.scalar_type() == at::kFloat || + grad_output.scalar_type() == at::kHalf, + "grad_output type should be Float or Half, got ", + grad_output.scalar_type(), "."); + TORCH_CHECK(grad_output.dim() == 4, "grad_output should be a 4d tensor, got ", + grad_output.dim(), "d."); + TORCH_CHECK(shift.dim() == 2, "shift should be a 2d tensor, got ", + shift.dim(), "d."); + TORCH_CHECK(grad_output.size(0) == shift.size(0), + "grad_output batch size should be the same as shift's, " + "grad_output batch size is ", + grad_output.size(0), ", shift batch size is ", shift.size(0), + "."); + TORCH_CHECK(grad_output.size(0) != 0, + "grad_output batch size should not be zero."); + TORCH_CHECK(grad_output.size(3) != 0, + "The last dim size of grad_output should not be zero."); + if (grad_output.size(1) == 0) { + return; + } + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + int channel_per_core = 0; + int max_number_hw_per_core = 0; + int max_length_per_core = 0; + policyFunc(grad_output, &k_dim, &k_type, &channel_per_core, + &max_number_hw_per_core, &max_length_per_core); + + const int batch_size = grad_output.size(0); + const int time_size = grad_output.size(1); + const int channel_size = grad_output.size(2); + const int hw_size = grad_output.size(3); + const int group_size = shift.size(1); + int group_channel = channel_size / group_size; + + // get tensor impl + auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output); + auto shift_impl = torch_mlu::getMluTensorImpl(shift); + auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get the mlu ptr + auto grad_output_ptr = grad_output_impl->cnnlMalloc(); + auto shift_ptr = shift_impl->cnnlMalloc(); + auto grad_input_ptr = grad_input_impl->cnnlMalloc(); + + cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype()); + + KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr, + grad_input_ptr, batch_size, time_size, channel_size, + hw_size, group_size, group_channel, data_dtype, + channel_per_core, max_number_hw_per_core, + max_length_per_core); +} diff --git a/mmcv/ops/csrc/pytorch/tin_shift.cpp b/mmcv/ops/csrc/pytorch/tin_shift.cpp index a10af24d3c..0df7ae3624 100644 --- a/mmcv/ops/csrc/pytorch/tin_shift.cpp +++ b/mmcv/ops/csrc/pytorch/tin_shift.cpp @@ -19,6 +19,24 @@ void tin_shift_backward_cuda(Tensor grad_output, Tensor shift, #endif +#ifdef MMCV_WITH_MLU +void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift, + Tensor output); + +void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift, + Tensor grad_input); + +void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) { + TINShiftForwardMLUKernelLauncher(input, shift, output); +} + +void tin_shift_backward_mlu(Tensor grad_output, Tensor shift, + Tensor grad_input) { + TINShiftBackwardMLUKernelLauncher(grad_output, shift, grad_input); +} + +#endif + void tin_shift_forward(Tensor input, Tensor shift, Tensor output) { if (input.device().is_cuda()) { #ifdef MMCV_WITH_CUDA @@ -29,6 +47,14 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output) { tin_shift_forward_cuda(input, shift, output); #else AT_ERROR("TINShift is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (input.device().type() == at::kMLU) { + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(shift); + CHECK_MLU_INPUT(output); + + tin_shift_forward_mlu(input, shift, output); #endif } else { AT_ERROR("TINShift is not implemented on CPU"); @@ -45,6 +71,14 @@ void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input) { tin_shift_backward_cuda(grad_output, shift, grad_input); #else AT_ERROR("TINShift is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (grad_output.device().type() == at::kMLU) { + CHECK_MLU_INPUT(grad_output); + CHECK_MLU_INPUT(shift); + CHECK_MLU_INPUT(grad_input); + + tin_shift_backward_mlu(grad_output, shift, grad_input); #endif } else { AT_ERROR("TINShift is not implemented on CPU"); diff --git a/mmcv/ops/tin_shift.py b/mmcv/ops/tin_shift.py old mode 100644 new mode 100755 index 472c9fcfe4..520e81c316 --- a/mmcv/ops/tin_shift.py +++ b/mmcv/ops/tin_shift.py @@ -18,6 +18,10 @@ class TINShiftFunction(Function): @staticmethod def forward(ctx, input, shift): + if input.size(0) != shift.size(0): + raise ValueError( + 'The first dim (batch) of `input` and `shift` should be same, ' + f'but got {input.size(0)} and {shift.size(0)}.') C = input.size(2) num_segments = shift.size(1) if C // num_segments <= 0 or C % num_segments != 0: diff --git a/tests/test_ops/test_tin_shift.py b/tests/test_ops/test_tin_shift.py old mode 100644 new mode 100755 index 93cea6ea58..f072624124 --- a/tests/test_ops/test_tin_shift.py +++ b/tests/test_ops/test_tin_shift.py @@ -4,6 +4,8 @@ import pytest import torch +from mmcv.utils import is_cuda, is_mlu + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -130,7 +132,7 @@ ] -def _test_tinshift_gradcheck(dtype): +def _test_tinshift_gradcheck(device, dtype): try: from mmcv.ops import tin_shift except ModuleNotFoundError: @@ -144,15 +146,15 @@ def _test_tinshift_gradcheck(dtype): np_shift = np.array(shift) x = torch.tensor( - np_input, dtype=dtype, device='cuda', requires_grad=True) - shift = torch.tensor(np_shift, device='cuda').int() + np_input, dtype=dtype, device=device, requires_grad=True) + shift = torch.tensor(np_shift, device=device).int() if torch.__version__ == 'parrots': gradcheck(tin_shift, (x, shift)) else: gradcheck(tin_shift, (x, shift), atol=1, rtol=0.1) -def _test_tinshift_allclose(dtype): +def _test_tinshift_allclose(device, dtype): try: from mmcv.ops import tin_shift except ModuleNotFoundError: @@ -165,8 +167,8 @@ def _test_tinshift_allclose(dtype): np_grad = np.array(grad) x = torch.tensor( - np_input, dtype=dtype, device='cuda', requires_grad=True) - shift = torch.tensor(np_shift, device='cuda').int() + np_input, dtype=dtype, device=device, requires_grad=True) + shift = torch.tensor(np_shift, device=device).int() output = tin_shift(x, shift) output.backward(torch.ones_like(output)) @@ -176,28 +178,47 @@ def _test_tinshift_allclose(dtype): x.grad.data.type(torch.float).cpu().numpy(), np_grad, 1e-3) -def _test_tinshift_assert(dtype): +def _test_tinshift_assert(device, dtype): try: from mmcv.ops import tin_shift except ModuleNotFoundError: pytest.skip('TINShift op is not successfully compiled') - inputs = [torch.rand(2, 3, 4, 2), torch.rand(2, 3, 4, 2)] + inputs = [ + torch.rand(2, 3, 4, 2), + torch.rand(2, 3, 4, 2), + torch.rand(1, 3, 4, 2) + ] shifts = [torch.rand(2, 3), torch.rand(2, 5)] for x, shift in zip(inputs, shifts): - x = x.cuda() - shift = shift.cuda() + x = x.to(device).type(dtype) + shift = shift.to(device).type(dtype) # A ValueError should be raised if ops get inputs with wrong shapes. with pytest.raises(ValueError): tin_shift(x, shift) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half]) -def test_tinshift(dtype): - _test_tinshift_allclose(dtype=dtype) - _test_tinshift_gradcheck(dtype=dtype) - _test_tinshift_assert(dtype=dtype) +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not is_cuda(), reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif(not is_mlu(), reason='requires MLU support')) +]) +@pytest.mark.parametrize('dtype', [ + torch.float, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + is_mlu(), + reason='MLU does not support for 64-bit floating point')), + torch.half +]) +def test_tinshift(device, dtype): + _test_tinshift_allclose(device=device, dtype=dtype) + _test_tinshift_gradcheck(device=device, dtype=dtype) + _test_tinshift_assert(device=device, dtype=dtype) From f0f4949fe8f9fb633b500aa60583e6438a38e357 Mon Sep 17 00:00:00 2001 From: jzwang <841713301@qq.com> Date: Thu, 7 Apr 2022 12:06:28 +0800 Subject: [PATCH 23/30] resolve conflicts and fix lint --- mmcv/ops/tin_shift.py | 4 ++-- mmcv/utils/__init__.py | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mmcv/ops/tin_shift.py b/mmcv/ops/tin_shift.py index 520e81c316..94d7e036e0 100755 --- a/mmcv/ops/tin_shift.py +++ b/mmcv/ops/tin_shift.py @@ -20,8 +20,8 @@ class TINShiftFunction(Function): def forward(ctx, input, shift): if input.size(0) != shift.size(0): raise ValueError( - 'The first dim (batch) of `input` and `shift` should be same, ' - f'but got {input.size(0)} and {shift.size(0)}.') + f'The first dim (batch) of `input` and `shift` should be' + f'same, but got {input.size(0)} and {shift.size(0)}.') C = input.size(2) num_segments = shift.size(1) if C // num_segments <= 0 or C % num_segments != 0: diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index b8d9afde74..8159c6a1a7 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -37,17 +37,21 @@ ] else: from .env import collect_env + from .hub import load_url from .logging import get_logger, print_log from .parrots_jit import jit, skip_no_elena - from .parrots_wrapper import ( - TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, - PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, - _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, - _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) - from .pytorch_wrapper import is_cuda, is_mlu + # yapf: disable + from .parrots_wrapper import (TORCH_VERSION, BuildExtension, CppExtension, + CUDAExtension, DataLoader, PoolDataLoader, + SyncBatchNorm, _AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, + _ConvNd, _ConvTransposeMixin, _get_cuda_home, + _InstanceNorm, _MaxPoolNd, get_build_config, + is_rocm_pytorch) + # yapf: enable from .registry import Registry, build_from_cfg + from .seed import worker_init_fn from .trace import is_jit_tracing - from .hub import load_url __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', @@ -67,5 +71,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method' + '_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn' ] From 70770a02acea18135d735e55ad762503b9b30000 Mon Sep 17 00:00:00 2001 From: jzwang <841713301@qq.com> Date: Thu, 7 Apr 2022 12:31:01 +0800 Subject: [PATCH 24/30] fix mmcv.utils.__init__ --- mmcv/utils/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 8159c6a1a7..019e77f568 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -50,7 +50,6 @@ is_rocm_pytorch) # yapf: enable from .registry import Registry, build_from_cfg - from .seed import worker_init_fn from .trace import is_jit_tracing __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -71,5 +70,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn' + '_get_cuda_home', 'load_url', 'has_method' ] From 587dc8120be87896964a52df8f852d662aa87ff3 Mon Sep 17 00:00:00 2001 From: jzwang <841713301@qq.com> Date: Thu, 7 Apr 2022 13:14:27 +0800 Subject: [PATCH 25/30] fix mmcv.utils.__init__ --- mmcv/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 019e77f568..7dc2c5bbd3 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -49,6 +49,7 @@ _InstanceNorm, _MaxPoolNd, get_build_config, is_rocm_pytorch) # yapf: enable + from .pytorch_wrapper import is_cuda, is_mlu from .registry import Registry, build_from_cfg from .trace import is_jit_tracing __all__ = [ @@ -70,5 +71,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method' + '_get_cuda_home', 'load_url', 'has_method', 'is_cuda', 'is_mlu' ] From 135e77d6a5b0311b22bde6b77a74a24ce45068c0 Mon Sep 17 00:00:00 2001 From: jzwang <841713301@qq.com> Date: Fri, 15 Apr 2022 15:29:07 +0800 Subject: [PATCH 26/30] Fix lints and change FLAG --- mmcv/device/mlu/__init__.py | 4 ++-- mmcv/device/mlu/utils.py | 2 +- mmcv/runner/dist_utils.py | 4 ++-- mmcv/utils/__init__.py | 5 +++-- mmcv/utils/logging.py | 2 +- mmcv/utils/pytorch_wrapper.py | 5 ++++- tests/test_device/test_mlu/test_mlu_parallel.py | 9 +++++---- tests/test_ops/test_bbox.py | 12 ++++++------ tests/test_ops/test_focal_loss.py | 12 ++++++------ tests/test_ops/test_nms.py | 8 ++++---- tests/test_ops/test_roi_align.py | 12 +++++++----- tests/test_ops/test_tin_shift.py | 12 +++++++----- 12 files changed, 48 insertions(+), 39 deletions(-) diff --git a/mmcv/device/mlu/__init__.py b/mmcv/device/mlu/__init__.py index 92681c5ef0..ffec020e94 100644 --- a/mmcv/device/mlu/__init__.py +++ b/mmcv/device/mlu/__init__.py @@ -2,9 +2,9 @@ from .data_parallel import MLUDataParallel from .distributed import MLUDistributedDataParallel from .scatter_gather import scatter, scatter_kwargs -from .utils import IS_MLU +from .utils import IS_MLU_AVAILABLE __all__ = [ 'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter', - 'scatter_kwargs', 'IS_MLU' + 'scatter_kwargs', 'IS_MLU_AVAILABLE' ] diff --git a/mmcv/device/mlu/utils.py b/mmcv/device/mlu/utils.py index 4158e6d441..ff9ea60c7e 100644 --- a/mmcv/device/mlu/utils.py +++ b/mmcv/device/mlu/utils.py @@ -8,4 +8,4 @@ def is_mlu_available(): return False -IS_MLU = is_mlu_available() +IS_MLU_AVAILABLE = is_mlu_available() diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py index 8a60dcc418..0914dd670d 100644 --- a/mmcv/runner/dist_utils.py +++ b/mmcv/runner/dist_utils.py @@ -12,7 +12,7 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors) -from mmcv.device.mlu import IS_MLU +from mmcv.device.mlu import IS_MLU_AVAILABLE def _find_free_port(): @@ -49,7 +49,7 @@ def init_dist(launcher, backend='nccl', **kwargs): def _init_dist_pytorch(backend, **kwargs): # TODO: use local_rank instead of rank % num_gpus rank = int(os.environ['RANK']) - if IS_MLU: + if IS_MLU_AVAILABLE: import torch_mlu # noqa: F401 torch.mlu.set_device(rank) dist.init_process_group( diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 77686a8c6e..5a02b96d7e 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -49,7 +49,7 @@ _InstanceNorm, _MaxPoolNd, get_build_config, is_rocm_pytorch) # yapf: enable - from .pytorch_wrapper import is_cuda + from .pytorch_wrapper import IS_CUDA_AVAILABLE from .registry import Registry, build_from_cfg from .seed import worker_init_fn from .trace import is_jit_tracing @@ -72,5 +72,6 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method', 'is_cuda', 'worker_init_fn' + '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE', + 'worker_init_fn' ] diff --git a/mmcv/utils/logging.py b/mmcv/utils/logging.py index 84dfc3303d..c4c7025f0e 100644 --- a/mmcv/utils/logging.py +++ b/mmcv/utils/logging.py @@ -89,7 +89,7 @@ def print_log(msg, logger=None, level=logging.INFO): msg (str): The message to be logged. logger (logging.Logger | str | None): The logger to be used. Some special loggers are: - + - "silent": no message will be printed. - other str: the logger obtained with `get_root_logger(logger)`. - None: The `print()` method will be used to print log messages. diff --git a/mmcv/utils/pytorch_wrapper.py b/mmcv/utils/pytorch_wrapper.py index d1c22aa96f..eb720ec4c8 100644 --- a/mmcv/utils/pytorch_wrapper.py +++ b/mmcv/utils/pytorch_wrapper.py @@ -4,5 +4,8 @@ TORCH_VERSION = torch.__version__ -def is_cuda() -> bool: +def is_cuda_available() -> bool: return torch.cuda.is_available() + + +IS_CUDA_AVAILABLE = is_cuda_available() diff --git a/tests/test_device/test_mlu/test_mlu_parallel.py b/tests/test_device/test_mlu/test_mlu_parallel.py index 7db9b495ca..3d51a68f82 100644 --- a/tests/test_device/test_mlu/test_mlu_parallel.py +++ b/tests/test_device/test_mlu/test_mlu_parallel.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn -from mmcv.device.mlu import IS_MLU, MLUDataParallel, MLUDistributedDataParallel +from mmcv.device.mlu import (IS_MLU_AVAILABLE, MLUDataParallel, + MLUDistributedDataParallel) from mmcv.device.mlu._functions import Scatter, scatter from mmcv.parallel import is_module_wrapper @@ -31,7 +32,7 @@ def forward(self, x): model = Model() assert not is_module_wrapper(model) - if IS_MLU: + if IS_MLU_AVAILABLE: mludp = MLUDataParallel(model) assert is_module_wrapper(mludp) @@ -51,7 +52,7 @@ def test_scatter(): assert torch.allclose(input, output) # if the device is MLU, copy the input from CPU to MLU - if IS_MLU: + if IS_MLU_AVAILABLE: input = torch.zeros([1, 3, 3, 3]) output = scatter(input=input, devices=[0]) assert torch.allclose(input.to('mlu'), output) @@ -82,7 +83,7 @@ def test_Scatter(): assert torch.allclose(input, output) # if the device is MLU, copy the input from CPU to MLU - if IS_MLU: + if IS_MLU_AVAILABLE: target_mlus = [0] input = torch.zeros([1, 3, 3, 3]) outputs = Scatter.forward(target_mlus, input) diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index 0ad6c04bb1..f75100ccf8 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -3,8 +3,8 @@ import pytest import torch -from mmcv.device.mlu import IS_MLU -from mmcv.utils import is_cuda +from mmcv.device.mlu import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE class TestBBox(object): @@ -39,11 +39,11 @@ def _test_bbox_overlaps(self, device, dtype=torch.float): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_bbox_overlaps_float(self, device): self._test_bbox_overlaps(device, dtype=torch.float) @@ -52,11 +52,11 @@ def test_bbox_overlaps_float(self, device): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_bbox_overlaps_half(self, device): self._test_bbox_overlaps(device, dtype=torch.half) diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py index 441e955a28..c28e640874 100644 --- a/tests/test_ops/test_focal_loss.py +++ b/tests/test_ops/test_focal_loss.py @@ -3,8 +3,8 @@ import pytest import torch -from mmcv.device.mlu import IS_MLU -from mmcv.utils import is_cuda +from mmcv.device.mlu import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE _USING_PARROTS = True try: @@ -134,11 +134,11 @@ def test_softmax_half(self): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_sigmoid_float(self, device): self._test_sigmoid(device=device, dtype=torch.float) @@ -147,11 +147,11 @@ def test_sigmoid_float(self, device): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_sigmoid_half(self, device): self._test_sigmoid(device, dtype=torch.half) diff --git a/tests/test_ops/test_nms.py b/tests/test_ops/test_nms.py index 485453cf52..d90f0afb17 100644 --- a/tests/test_ops/test_nms.py +++ b/tests/test_ops/test_nms.py @@ -3,8 +3,8 @@ import pytest import torch -from mmcv.device.mlu import IS_MLU -from mmcv.utils import is_cuda +from mmcv.device.mlu import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE class Testnms(object): @@ -13,11 +13,11 @@ class Testnms(object): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_nms_allclose(self, device): from mmcv.ops import nms diff --git a/tests/test_ops/test_roi_align.py b/tests/test_ops/test_roi_align.py index 83bdc282f2..cc6f92b8cb 100644 --- a/tests/test_ops/test_roi_align.py +++ b/tests/test_ops/test_roi_align.py @@ -3,8 +3,8 @@ import pytest import torch -from mmcv.device.mlu import IS_MLU -from mmcv.utils import is_cuda +from mmcv.device.mlu import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE _USING_PARROTS = True try: @@ -99,17 +99,19 @@ def _test_roialign_allclose(device, dtype): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', - marks=pytest.mark.skipif(not IS_MLU, reason='requires MLU support')) + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU, reason='MLU does not support for 64-bit floating point')), + IS_MLU_AVAILABLE, + reason='MLU does not support for 64-bit floating point')), torch.half ]) def test_roialign(device, dtype): diff --git a/tests/test_ops/test_tin_shift.py b/tests/test_ops/test_tin_shift.py index 2eeb43c91e..7a32aed71e 100755 --- a/tests/test_ops/test_tin_shift.py +++ b/tests/test_ops/test_tin_shift.py @@ -5,8 +5,8 @@ import pytest import torch -from mmcv.device.mlu import IS_MLU -from mmcv.utils import is_cuda +from mmcv.device.mlu import IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE _USING_PARROTS = True try: @@ -206,17 +206,19 @@ def _test_tinshift_assert(device, dtype): pytest.param( 'cuda', marks=pytest.mark.skipif( - not is_cuda(), reason='requires CUDA support')), + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), pytest.param( 'mlu', - marks=pytest.mark.skipif(not IS_MLU, reason='requires MLU support')) + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU, reason='MLU does not support for 64-bit floating point')), + IS_MLU_AVAILABLE, + reason='MLU does not support for 64-bit floating point')), torch.half ]) def test_tinshift(device, dtype): From a67c28ded514c4d55b32fc71c48fad5f3c2186de Mon Sep 17 00:00:00 2001 From: jzwang <841713301@qq.com> Date: Fri, 15 Apr 2022 17:18:45 +0800 Subject: [PATCH 27/30] fix setup and refine --- mmcv/utils/__init__.py | 12 ++++++------ mmcv/utils/parrots_wrapper.py | 7 +++++++ mmcv/utils/pytorch_wrapper.py | 11 ----------- setup.py | 6 +----- 4 files changed, 14 insertions(+), 22 deletions(-) delete mode 100644 mmcv/utils/pytorch_wrapper.py diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 5a02b96d7e..774259b4e2 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -41,15 +41,15 @@ from .logging import get_logger, print_log from .parrots_jit import jit, skip_no_elena # yapf: disable - from .parrots_wrapper import (TORCH_VERSION, BuildExtension, CppExtension, - CUDAExtension, DataLoader, PoolDataLoader, - SyncBatchNorm, _AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, - _ConvNd, _ConvTransposeMixin, _get_cuda_home, + from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION, + BuildExtension, CppExtension, CUDAExtension, + DataLoader, PoolDataLoader, SyncBatchNorm, + _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _BatchNorm, _ConvNd, + _ConvTransposeMixin, _get_cuda_home, _InstanceNorm, _MaxPoolNd, get_build_config, is_rocm_pytorch) # yapf: enable - from .pytorch_wrapper import IS_CUDA_AVAILABLE from .registry import Registry, build_from_cfg from .seed import worker_init_fn from .trace import is_jit_tracing diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py index cf6bf9ef86..7e657b5616 100644 --- a/mmcv/utils/parrots_wrapper.py +++ b/mmcv/utils/parrots_wrapper.py @@ -6,6 +6,13 @@ TORCH_VERSION = torch.__version__ +def is_cuda_available() -> bool: + return torch.cuda.is_available() + + +IS_CUDA_AVAILABLE = is_cuda_available() + + def is_rocm_pytorch() -> bool: is_rocm = False if TORCH_VERSION != 'parrots': diff --git a/mmcv/utils/pytorch_wrapper.py b/mmcv/utils/pytorch_wrapper.py deleted file mode 100644 index eb720ec4c8..0000000000 --- a/mmcv/utils/pytorch_wrapper.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -TORCH_VERSION = torch.__version__ - - -def is_cuda_available() -> bool: - return torch.cuda.is_available() - - -IS_CUDA_AVAILABLE = is_cuda_available() diff --git a/setup.py b/setup.py index fda95535e7..d043459d4c 100644 --- a/setup.py +++ b/setup.py @@ -17,11 +17,7 @@ from torch_mlu.utils.cpp_extension import BuildExtension EXT_TYPE = 'pytorch' else: - try: - if torch.is_mlu_available(): - from torch_mlu.utils.cpp_extension import BuildExtension - except AttributeError: - from torch.utils.cpp_extension import BuildExtension + from torch.utils.cpp_extension import BuildExtension EXT_TYPE = 'pytorch' cmd_class = {'build_ext': BuildExtension} except ModuleNotFoundError: From 1e1b26206136147b3ff7f0232630687520ecbdb4 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Fri, 15 Apr 2022 23:09:34 +0800 Subject: [PATCH 28/30] remove a redundant line --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index d043459d4c..7d35cd48f8 100644 --- a/setup.py +++ b/setup.py @@ -302,7 +302,6 @@ def get_extensions(): op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') extension = MLUExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu')) From 9150d25af1130fc08fa29202221f5a0d0f67195d Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sat, 16 Apr 2022 11:13:35 +0800 Subject: [PATCH 29/30] remove an unnecessary 'f' --- mmcv/ops/tin_shift.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/tin_shift.py b/mmcv/ops/tin_shift.py index 5d7b47d8fa..473231cc0d 100755 --- a/mmcv/ops/tin_shift.py +++ b/mmcv/ops/tin_shift.py @@ -20,7 +20,7 @@ class TINShiftFunction(Function): def forward(ctx, input, shift): if input.size(0) != shift.size(0): raise ValueError( - f'The first dim (batch) of `input` and `shift` should be' + 'The first dim (batch) of `input` and `shift` should be ' f'same, but got {input.size(0)} and {shift.size(0)}.') C = input.size(2) num_segments = shift.size(1) From 8d9c65d5dcd2e0c1aa1496ba3def8141b5f5cb62 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sat, 16 Apr 2022 11:16:15 +0800 Subject: [PATCH 30/30] fix compilation error --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7d35cd48f8..d34706a267 100644 --- a/setup.py +++ b/setup.py @@ -301,7 +301,7 @@ def get_extensions(): extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') extension = MLUExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))