From 030011c5205c5da3bd40d9311b7f32f20fc25720 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 26 Jun 2022 22:17:31 +0800 Subject: [PATCH 1/9] reorgnize mmrotate --- .../codebase/mmrotate/core/bbox/__init__.py | 1 + .../core/bbox/gliding_vertex_coder.py | 31 ++++ mmdeploy/codebase/mmrotate/models/__init__.py | 21 +-- .../mmrotate/models/dense_heads/__init__.py | 9 ++ .../models/dense_heads/oriented_rpn_head.py | 141 ++++++++++++++++++ .../{ => dense_heads}/rotated_anchor_head.py | 0 .../{ => dense_heads}/rotated_rpn_head.py | 16 +- .../mmrotate/models/roi_heads/__init__.py | 14 ++ .../mmrotate/models/roi_heads/gv_bbox_head.py | 85 +++++++++++ .../models/roi_heads/gv_ratio_roi_head.py | 73 +++++++++ .../oriented_standard_roi_head.py | 8 +- .../models/{ => roi_heads}/roi_extractors.py | 0 .../{ => roi_heads}/rotated_bbox_head.py | 0 13 files changed, 371 insertions(+), 28 deletions(-) create mode 100644 mmdeploy/codebase/mmrotate/core/bbox/gliding_vertex_coder.py create mode 100644 mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py create mode 100644 mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py rename mmdeploy/codebase/mmrotate/models/{ => dense_heads}/rotated_anchor_head.py (100%) rename mmdeploy/codebase/mmrotate/models/{ => dense_heads}/rotated_rpn_head.py (93%) create mode 100644 mmdeploy/codebase/mmrotate/models/roi_heads/__init__.py create mode 100644 mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py create mode 100644 mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py rename mmdeploy/codebase/mmrotate/models/{ => roi_heads}/oriented_standard_roi_head.py (93%) rename mmdeploy/codebase/mmrotate/models/{ => roi_heads}/roi_extractors.py (100%) rename mmdeploy/codebase/mmrotate/models/{ => roi_heads}/rotated_bbox_head.py (100%) diff --git a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py index 22ef641430..d7f70b075a 100644 --- a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py +++ b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403 from .delta_xywha_rbbox_coder import * # noqa: F401,F403 +from .gliding_vertex_coder import * # noqa: F401,F403 from .transforms import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/bbox/gliding_vertex_coder.py b/mmdeploy/codebase/mmrotate/core/bbox/gliding_vertex_coder.py new file mode 100644 index 0000000000..3e7c07955a --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/bbox/gliding_vertex_coder.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.core.bbox.coder.gliding_vertex_coder' + '.GVFixCoder.decode') +def gvfixcoder__decode(ctx, self, hbboxes, fix_deltas): + """Rewriter for GVFixCoder decode, support more dimension input.""" + + from mmrotate.core.bbox.transforms import poly2obb + x1 = hbboxes[..., 0::4] + y1 = hbboxes[..., 1::4] + x2 = hbboxes[..., 2::4] + y2 = hbboxes[..., 3::4] + w = hbboxes[..., 2::4] - hbboxes[..., 0::4] + h = hbboxes[..., 3::4] - hbboxes[..., 1::4] + + pred_t_x = x1 + w * fix_deltas[..., 0::4] + pred_r_y = y1 + h * fix_deltas[..., 1::4] + pred_d_x = x2 - w * fix_deltas[..., 2::4] + pred_l_y = y2 - h * fix_deltas[..., 3::4] + + polys = torch.stack( + [pred_t_x, y1, x2, pred_r_y, pred_d_x, y2, x1, pred_l_y], dim=-1) + polys = polys.flatten(2) + rbboxes = poly2obb(polys, self.version) + + return rbboxes diff --git a/mmdeploy/codebase/mmrotate/models/__init__.py b/mmdeploy/codebase/mmrotate/models/__init__.py index 5fe7e5a1c7..65edb9dbae 100644 --- a/mmdeploy/codebase/mmrotate/models/__init__.py +++ b/mmdeploy/codebase/mmrotate/models/__init__.py @@ -1,19 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .oriented_standard_roi_head import ( - oriented_standard_roi_head__simple_test, - oriented_standard_roi_head__simple_test_bboxes) -from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt -from .rotated_anchor_head import rotated_anchor_head__get_bbox -from .rotated_bbox_head import rotated_bbox_head__get_bboxes -from .rotated_rpn_head import rotated_rpn_head__get_bboxes -from .single_stage_rotated_detector import \ - single_stage_rotated_detector__simple_test - -__all__ = [ - 'single_stage_rotated_detector__simple_test', - 'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes', - 'oriented_standard_roi_head__simple_test', - 'oriented_standard_roi_head__simple_test_bboxes', - 'rotated_bbox_head__get_bboxes', - 'rotated_single_roi_extractor__forward__tensorrt' -] +from .dense_heads import * # noqa: F401,F403 +from .roi_heads import * # noqa: F401,F403 +from .single_stage_rotated_detector import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py b/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py new file mode 100644 index 0000000000..90163f8351 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .oriented_rpn_head import oriented_rpn_head__get_bboxes +from .rotated_anchor_head import rotated_anchor_head__get_bbox +from .rotated_rpn_head import rotated_rpn_head__get_bboxes + +__all__ = [ + 'oriented_rpn_head__get_bboxes', 'rotated_anchor_head__get_bbox', + 'rotated_rpn_head__get_bboxes' +] diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py new file mode 100644 index 0000000000..9b2adfc2f1 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.codebase.mmdet import (get_post_processing_params, + pad_with_value_if_necessary) +from mmdeploy.codebase.mmrotate.core.post_processing import \ + fake_multiclass_nms_rotated +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.dense_heads.OrientedRPNHead.get_bboxes') +def oriented_rpn_head__get_bboxes(ctx, + self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Rewrite `get_bboxes` of `RPNHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + self (FoveaHead): The instance of the class FoveaHead. + cls_scores (list[Tensor]): Box scores for each scale level + with shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict]): Meta information of the image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. Default: None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores + """ + assert len(cls_scores) == len(bbox_preds) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device=device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors) + + cfg = self.test_cfg if cfg is None else cfg + batch_size = mlvl_cls_scores[0].shape[0] + pre_topk = cfg.get('nms_pre', -1) + + # loop over features, decode boxes + mlvl_valid_bboxes = [] + mlvl_scores = [] + mlvl_valid_anchors = [] + for level_id, cls_score, bbox_pred, anchors in zip( + range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(0, 2, 3, 1) + if self.use_sigmoid_cls: + cls_score = cls_score.reshape(batch_size, -1) + scores = cls_score.sigmoid() + else: + cls_score = cls_score.reshape(batch_size, -1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. + scores = cls_score.softmax(-1)[..., 0] + scores = scores.reshape(batch_size, -1, 1) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6) + + # use static anchor if input shape is static + if not is_dynamic_flag: + anchors = anchors.data + + anchors = anchors.unsqueeze(0) + + # topk in tensorrt does not support shape 0: + _, topk_inds = scores.squeeze(2).topk(pre_topk) + batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1) + prior_inds = topk_inds.new_zeros((1, 1)) + anchors = anchors[prior_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + mlvl_valid_bboxes.append(bbox_pred) + mlvl_scores.append(scores) + mlvl_valid_anchors.append(anchors) + + batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) + batch_mlvl_bboxes = self.bbox_coder.decode( + batch_mlvl_anchors, + batch_mlvl_bboxes, + max_shape=img_metas[0]['img_shape']) + # ignore background class + if not self.use_sigmoid_cls: + batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes] + if not with_nms: + return batch_mlvl_bboxes, batch_mlvl_scores + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + # only one class in rpn + max_output_boxes_per_class = keep_top_k + return fake_multiclass_nms_rotated( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + keep_top_k=keep_top_k, + version=self.version) diff --git a/mmdeploy/codebase/mmrotate/models/rotated_anchor_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py similarity index 100% rename from mmdeploy/codebase/mmrotate/models/rotated_anchor_head.py rename to mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py diff --git a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py similarity index 93% rename from mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py rename to mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py index adef7b5b90..d7c8524be8 100644 --- a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py @@ -3,8 +3,9 @@ from mmdeploy.codebase.mmdet import (get_post_processing_params, pad_with_value_if_necessary) -from mmdeploy.codebase.mmrotate.core.post_processing import \ - fake_multiclass_nms_rotated +# from mmdeploy.codebase.mmrotate.core.post_processing import \ +# multiclass_nms_rotated +from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import is_dynamic_shape @@ -89,7 +90,7 @@ def rotated_rpn_head__get_bboxes(ctx, # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. scores = cls_score.softmax(-1)[..., 0] scores = scores.reshape(batch_size, -1, 1) - bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) # use static anchor if input shape is static if not is_dynamic_flag: @@ -129,13 +130,16 @@ def rotated_rpn_head__get_bboxes(ctx, post_params = get_post_processing_params(deploy_cfg) iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) # only one class in rpn max_output_boxes_per_class = keep_top_k - return fake_multiclass_nms_rotated( + return multiclass_nms( batch_mlvl_bboxes, batch_mlvl_scores, max_output_boxes_per_class, iou_threshold=iou_threshold, - keep_top_k=keep_top_k, - version=self.version) + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/__init__.py b/mmdeploy/codebase/mmrotate/models/roi_heads/__init__.py new file mode 100644 index 0000000000..6a3b2ccef8 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .gv_bbox_head import gv_bbox_head__get_bboxes +from .gv_ratio_roi_head import gv_ratio_roi_head__simple_test_bboxes +from .oriented_standard_roi_head import \ + oriented_standard_roi_head__simple_test_bboxes +from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt +from .rotated_bbox_head import rotated_bbox_head__get_bboxes + +__all__ = [ + 'gv_bbox_head__get_bboxes', 'gv_ratio_roi_head__simple_test_bboxes', + 'oriented_standard_roi_head__simple_test_bboxes', + 'rotated_single_roi_extractor__forward__tensorrt', + 'rotated_bbox_head__get_bboxes' +] diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py new file mode 100644 index 0000000000..528ebb67d0 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmrotate.core import hbb2obb + +from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmrotate.core.post_processing import \ + multiclass_nms_rotated +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.get_bboxes') +def gv_bbox_head__get_bboxes(ctx, + self, + rois, + cls_score, + bbox_pred, + fix_pred, + ratio_pred, + img_shape, + scale_factor, + rescale=False, + cfg=None): + """Transform network output for a batch into bbox predictions. + + Args: + rois (torch.Tensor): Boxes to be transformed. Has shape + (num_boxes, 6). last dimension 5 arrange as + (batch_index, x, y, w, h, theta). + cls_score (torch.Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor, optional): Box energies / deltas. + has shape (num_boxes, num_classes * 6). + img_shape (Sequence[int], optional): Maximum bounds for boxes, + specifies (H, W, C) or (H, W). + scale_factor (ndarray): Scale factor of the + image arrange as (w_scale, h_scale, w_scale, h_scale). + rescale (bool): If True, return boxes in original image space. + Default: False. + cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None + + Returns: + tuple[Tensor, Tensor]: + First tensor is `det_bboxes`, has the shape + (num_boxes, 6) and last + dimension 6 represent (cx, cy, w, h, theta, score). + Second tensor is the labels with shape (num_boxes, ). + """ + assert rois.ndim == 3, 'Only support export two stage ' \ + 'model to ONNX ' \ + 'with batch dimension. ' + + if self.custom_cls_channels: + scores = self.loss_cls.get_activation(cls_score) + else: + scores = F.softmax( + cls_score, dim=-1) if cls_score is not None else None + + assert bbox_pred is not None + bboxes = self.bbox_coder.decode( + rois[..., 1:], bbox_pred, max_shape=img_shape) + + rbboxes = self.fix_coder.decode(bboxes, fix_pred) + + bboxes = bboxes.view(*ratio_pred.size(), 4) + rbboxes = rbboxes.view(*ratio_pred.size(), 5) + rbboxes = torch.where(ratio_pred > self.ratio_thr, + hbb2obb(bboxes, self.version), rbboxes) + # ignore background class + scores = scores[..., :self.num_classes] + + post_params = get_post_processing_params(ctx.cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + return multiclass_nms_rotated( + rbboxes, + scores, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py new file mode 100644 index 0000000000..6582d3fbd3 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.gv_ratio_roi_head' + '.GVRatioRoIHead.simple_test_bboxes') +def gv_ratio_roi_head__simple_test_bboxes(ctx, + self, + x, + img_metas, + proposals, + rcnn_test_cfg, + rescale=False): + """Test only det bboxes without augmentation. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + img_metas (list[dict]): Image meta info. + proposals (List[Tensor]): Region proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Default: False. + + Returns: + tuple[list[Tensor], list[Tensor]]: The first list contains \ + the boxes of the corresponding image in a batch, each \ + tensor has the shape (num_boxes, 6) and last dimension \ + 6 represent (x, y, w, h, theta, score). Each Tensor \ + in the second list is the labels with shape (num_boxes, ). \ + The length of both lists should be equal to batch_size. + """ + + rois, labels = proposals + batch_index = torch.arange( + rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand( + rois.size(0), rois.size(1), 1) + rois = torch.cat([batch_index, rois[..., :4]], dim=-1) + batch_size = rois.shape[0] + num_proposals_per_img = rois.shape[1] + + # Eliminate the batch dimension + rois = rois.view(-1, 5) + bbox_results = self._bbox_forward(x, rois) + cls_score = bbox_results['cls_score'] + bbox_pred = bbox_results['bbox_pred'] + fix_pred = bbox_results['fix_pred'] + ratio_pred = bbox_results['ratio_pred'] + + # Recover the batch dimension + rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) + cls_score = cls_score.reshape(batch_size, num_proposals_per_img, + cls_score.size(-1)) + + bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, + bbox_pred.size(-1)) + fix_pred = fix_pred.reshape(batch_size, num_proposals_per_img, + fix_pred.size(-1)) + ratio_pred = ratio_pred.reshape(batch_size, num_proposals_per_img, + ratio_pred.size(-1)) + det_bboxes, det_labels = self.bbox_head.get_bboxes( + rois, + cls_score, + bbox_pred, + fix_pred, + ratio_pred, + img_metas[0]['img_shape'], + None, + rescale=rescale, + cfg=self.test_cfg) + return det_bboxes, det_labels diff --git a/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/oriented_standard_roi_head.py similarity index 93% rename from mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py rename to mmdeploy/codebase/mmrotate/models/roi_heads/oriented_standard_roi_head.py index f977e20d64..119ec11d6b 100644 --- a/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/oriented_standard_roi_head.py @@ -5,10 +5,10 @@ @FUNCTION_REWRITER.register_rewriter( - 'mmrotate.models.roi_heads.oriented_standard_roi_head' - '.OrientedStandardRoIHead.simple_test') -def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas, - **kwargs): + 'mmrotate.models.roi_heads.rotate_standard_roi_head' + '.RotatedStandardRoIHead.simple_test') +def rotate_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas, + **kwargs): """Rewrite `simple_test` of `StandardRoIHead` for default backend. This function returns detection result as Tensor instead of numpy diff --git a/mmdeploy/codebase/mmrotate/models/roi_extractors.py b/mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py similarity index 100% rename from mmdeploy/codebase/mmrotate/models/roi_extractors.py rename to mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py diff --git a/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py similarity index 100% rename from mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py rename to mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py From 1f04b0e2f0a75a507536b5e039eeeaf48f47d90b Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 26 Jun 2022 23:21:01 +0800 Subject: [PATCH 2/9] fix --- .../mmrotate/models/roi_heads/gv_bbox_head.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index 528ebb67d0..94ee2b4481 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -1,7 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -import torch import torch.nn.functional as F -from mmrotate.core import hbb2obb from mmdeploy.codebase.mmdet import get_post_processing_params from mmdeploy.codebase.mmrotate.core.post_processing import \ @@ -65,8 +63,13 @@ def gv_bbox_head__get_bboxes(ctx, bboxes = bboxes.view(*ratio_pred.size(), 4) rbboxes = rbboxes.view(*ratio_pred.size(), 5) - rbboxes = torch.where(ratio_pred > self.ratio_thr, - hbb2obb(bboxes, self.version), rbboxes) + + # TODO: Find a way to fix the usage of ratio_pred + # from mmrotate.core import hbb2obb + # rbboxes = rbboxes.where( + # ratio_pred.unsqueeze(-1).expand_as(rbboxes) > self.ratio_thr, + # hbb2obb(bboxes, self.version)) + # ignore background class scores = scores[..., :self.num_classes] From 3afb38d00769c75c7fe3569cabff99dbae9afb1b Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 27 Jun 2022 10:56:59 +0800 Subject: [PATCH 3/9] add hbb2obb --- .../codebase/mmrotate/models/roi_heads/gv_bbox_head.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index 94ee2b4481..aef1253909 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -64,11 +64,10 @@ def gv_bbox_head__get_bboxes(ctx, bboxes = bboxes.view(*ratio_pred.size(), 4) rbboxes = rbboxes.view(*ratio_pred.size(), 5) - # TODO: Find a way to fix the usage of ratio_pred - # from mmrotate.core import hbb2obb - # rbboxes = rbboxes.where( - # ratio_pred.unsqueeze(-1).expand_as(rbboxes) > self.ratio_thr, - # hbb2obb(bboxes, self.version)) + from mmrotate.core import hbb2obb + rbboxes = rbboxes.where( + ratio_pred.unsqueeze(-1) < self.ratio_thr, + hbb2obb(bboxes, self.version)) # ignore background class scores = scores[..., :self.num_classes] From 844b8fc483d8c2b3b5cb4f1a140cd3020f5bd0b9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 27 Jun 2022 15:32:46 +0800 Subject: [PATCH 4/9] add ut --- .../mmrotate/models/roi_heads/gv_bbox_head.py | 1 + .../test_mmrotate/test_mmrotate_core.py | 34 ++- .../test_mmrotate/test_mmrotate_models.py | 219 ++++++++++++++++++ 3 files changed, 252 insertions(+), 2 deletions(-) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index aef1253909..af64440ef9 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -68,6 +68,7 @@ def gv_bbox_head__get_bboxes(ctx, rbboxes = rbboxes.where( ratio_pred.unsqueeze(-1) < self.ratio_thr, hbb2obb(bboxes, self.version)) + rbboxes = rbboxes.squeeze(2) # ignore background class scores = scores[..., :self.num_classes] diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py index e9c8d4936c..fb9e9081bf 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -6,8 +6,9 @@ from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase -from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend, - get_onnx_model, get_rewrite_outputs) +from mmdeploy.utils.test import (WrapFunction, WrapModel, backend_checker, + check_backend, get_onnx_model, + get_rewrite_outputs) try: import_codebase(Codebase.MMROTATE) @@ -309,3 +310,32 @@ def poly2obb_le90(*args, **kwargs): run_with_backend=False) assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_gvfixcoder__decode(backend_type: Backend): + check_backend(backend_type) + + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=['output'], input_shape=None), + backend_config=dict(type=backend_type.value), + codebase_config=dict(type='mmrotate', task='RotatedDetection'))) + + from mmrotate.core.bbox import GVFixCoder + coder = GVFixCoder(angle_range='le90') + + hbboxes = torch.rand(1, 10, 4) + fix_deltas = torch.rand(1, 10, 4) + + wrapped_model = WrapModel(coder, 'decode') + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model, + model_inputs={ + 'hbboxes': hbboxes, + 'fix_deltas': fix_deltas + }, + deploy_cfg=deploy_cfg, + run_with_backend=False) + + assert rewrite_outputs is not None diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py index 656a2b4e20..52c3610ada 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -332,3 +332,222 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): model_inputs=rewrite_inputs, deploy_cfg=deploy_cfg) assert rewrite_outputs is not None + + +def get_rotated_rpn_head_model(): + """Oriented RPN Head Config.""" + test_cfg = mmcv.Config( + dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + from mmrotate.models.dense_heads import RotatedRPNHead + model = RotatedRPNHead( + version='le90', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + test_cfg=test_cfg) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend): + check_backend(backend_type) + head = get_rotated_rpn_head_model() + head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000)))) + + # the cls_score's size: (1, 3, 32, 32), (1, 3, 16, 16), + # (1, 3, 8, 8), (1, 3, 4, 4), (1, 3, 2, 2). + # the bboxes's size: (1, 18, 32, 32), (1, 18, 16, 16), + # (1, 18, 8, 8), (1, 18, 4, 4), (1, 18, 2, 2) + seed_everything(1234) + cls_score = [ + torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(5678) + bboxes = [torch.rand(1, 18, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.Tensor([s, s]) + wrapped_model = WrapModel( + head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_rotate_standard_roi_head__simple_test(backend_type: Backend): + check_backend(backend_type) + from mmrotate.models.roi_heads import OrientedStandardRoIHead + output_names = ['dets', 'labels'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000)))) + angle_version = 'le90' + test_cfg = mmcv.Config( + dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + head = OrientedStandardRoIHead( + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=3, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='RotatedShared2FCBBoxHead', + in_channels=3, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + bbox_coder=dict( + type='DeltaXYWHAOBBoxCoder', + angle_range=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True), + test_cfg=test_cfg) + head.cpu().eval() + + seed_everything(1234) + x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)] + proposals = [torch.rand(1, 100, 6), torch.randint(0, 10, (1, 100))] + img_metas = [{'img_shape': torch.tensor([224, 224])}] + + wrapped_model = WrapModel( + head, 'simple_test', proposals=proposals, img_metas=img_metas) + rewrite_inputs = {'x': x} + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_gv_ratio_roi_head__simple_test(backend_type: Backend): + check_backend(backend_type) + from mmrotate.models.roi_heads import GVRatioRoIHead + output_names = ['dets', 'labels'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000)))) + angle_version = 'le90' + test_cfg = mmcv.Config( + dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + head = GVRatioRoIHead( + version=angle_version, + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=3, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='GVBBoxHead', + version=angle_version, + num_shared_fcs=2, + in_channels=3, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + ratio_thr=0.8, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=(.0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2)), + fix_coder=dict(type='GVFixCoder', angle_range=angle_version), + ratio_coder=dict(type='GVRatioCoder', angle_range=angle_version), + reg_class_agnostic=True), + test_cfg=test_cfg) + head.cpu().eval() + + seed_everything(1234) + x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)] + bboxes = torch.rand(1, 100, 2) + bboxes = torch.cat( + [bboxes, bboxes + torch.rand(1, 100, 2) + torch.rand(1, 100, 1)], + dim=-1) + proposals = [bboxes, torch.randint(0, 10, (1, 100))] + img_metas = [{'img_shape': torch.tensor([224, 224])}] + + wrapped_model = WrapModel( + head, 'simple_test', proposals=proposals, img_metas=img_metas) + rewrite_inputs = {'x': x} + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + assert rewrite_outputs is not None From f0b665acb090e3635921f42f2583914d458b9705 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 28 Jun 2022 12:39:51 +0800 Subject: [PATCH 5/9] fix rotated nms --- configs/mmrotate/rotated-detection_static.py | 3 +- .../common_impl/nms/allClassRotatedNMS.cu | 2 +- .../mmrotate/core/post_processing/bbox_nms.py | 1 + .../deploy/rotated_detection_model.py | 28 +++++++++++++++++++ .../mmrotate/models/roi_heads/gv_bbox_head.py | 2 ++ 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/configs/mmrotate/rotated-detection_static.py b/configs/mmrotate/rotated-detection_static.py index 324de6f7f7..b696260e26 100644 --- a/configs/mmrotate/rotated-detection_static.py +++ b/configs/mmrotate/rotated-detection_static.py @@ -6,4 +6,5 @@ score_threshold=0.05, iou_threshold=0.1, pre_top_k=3000, - keep_top_k=2000)) + keep_top_k=2000, + max_output_boxes_per_class=2000)) diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu index 8d3858deae..0edea2bfaf 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu @@ -295,7 +295,7 @@ __host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1 const T area1 = box1.w * box1.h; const T area2 = box2.w * box2.h; if (area1 < 1e-14 || area2 < 1e-14) { - return 0.f; + return 1.0f; } const T intersection = rotated_boxes_intersection(box1, box2); diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py index 336f49aa3c..4a7b8375be 100644 --- a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py @@ -77,6 +77,7 @@ def select_rnms_index(scores: torch.Tensor, def _multiclass_nms_rotated(boxes: Tensor, scores: Tensor, + max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.1, score_threshold: float = 0.05, pre_top_k: int = -1, diff --git a/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py b/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py index dd73665f4f..f864e17d9b 100644 --- a/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py +++ b/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py @@ -75,6 +75,33 @@ def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], output_names=output_names, deploy_cfg=self.deploy_cfg) + @staticmethod + def __clear_outputs( + test_outputs: List[Union[torch.Tensor, np.ndarray]] + ) -> List[Union[List[torch.Tensor], List[np.ndarray]]]: + """Removes additional outputs and detections with zero and negative + score. + + Args: + test_outputs (List[Union[torch.Tensor, np.ndarray]]): + outputs of forward_test. + + Returns: + List[Union[List[torch.Tensor], List[np.ndarray]]]: + outputs with without zero score object. + """ + batch_size = len(test_outputs[0]) + + num_outputs = len(test_outputs) + outputs = [[None for _ in range(batch_size)] + for _ in range(num_outputs)] + + for i in range(batch_size): + inds = test_outputs[0][i, :, -1] > 0.0 + for output_id in range(num_outputs): + outputs[output_id][i] = test_outputs[output_id][i, inds, ...] + return outputs + def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list: """Run forward inference. @@ -91,6 +118,7 @@ def forward(self, img: Sequence[torch.Tensor], input_img = img[0].contiguous() img_metas = img_metas[0] outputs = self.forward_test(input_img, img_metas, *args, **kwargs) + outputs = End2EndModel.__clear_outputs(outputs) batch_dets, batch_labels = outputs[:2] batch_size = input_img.shape[0] rescale = kwargs.get('rescale', False) diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index af64440ef9..3d777218d8 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -74,6 +74,7 @@ def gv_bbox_head__get_bboxes(ctx, scores = scores[..., :self.num_classes] post_params = get_post_processing_params(ctx.cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) score_threshold = cfg.get('score_thr', post_params.score_threshold) pre_top_k = post_params.pre_top_k @@ -82,6 +83,7 @@ def gv_bbox_head__get_bboxes(ctx, return multiclass_nms_rotated( rbboxes, scores, + max_output_boxes_per_class, iou_threshold=iou_threshold, score_threshold=score_threshold, pre_top_k=pre_top_k, From 8c86052db3119e1c7ac9340a297cd106102ddb33 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 29 Jun 2022 15:38:20 +0800 Subject: [PATCH 6/9] update docs --- docs/en/03-benchmark/supported_models.md | 2 ++ docs/en/04-supported-codebases/mmrotate.md | 1 + docs/zh_cn/03-benchmark/supported_models.md | 1 + 3 files changed, 4 insertions(+) diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index 0cba5731eb..245647efa1 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -72,6 +72,8 @@ The table below lists the models that are guaranteed to be exportable to other b | PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | | CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | | RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | +| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | ### Note diff --git a/docs/en/04-supported-codebases/mmrotate.md b/docs/en/04-supported-codebases/mmrotate.md index 1b9e403ae1..5594cc7529 100644 --- a/docs/en/04-supported-codebases/mmrotate.md +++ b/docs/en/04-supported-codebases/mmrotate.md @@ -12,6 +12,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en | :--------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: | | RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | | Oriented RCNN | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Gliding Vertex | RotatedDetection | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | ### Example diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index 2a5f24b2c3..9324fa75fb 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -71,6 +71,7 @@ | CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | | RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | | Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) | ## Note From f08551bb0be412e3dd09906340c439e3a9394850 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 1 Jul 2022 12:28:26 +0800 Subject: [PATCH 7/9] update benchmark --- docs/en/03-benchmark/benchmark.md | 12 ++++++++++++ docs/zh_cn/03-benchmark/benchmark.md | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/docs/en/03-benchmark/benchmark.md b/docs/en/03-benchmark/benchmark.md index 31bf9b9ca1..2bb444415b 100644 --- a/docs/en/03-benchmark/benchmark.md +++ b/docs/en/03-benchmark/benchmark.md @@ -1618,6 +1618,18 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - + + GlidingVertex + Rotated Detection + DOTA-v1.0 + mAP + 0.732 + - + 0.733 + 0.731 + - + - + diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md index 7bbaf98db8..78465aab94 100644 --- a/docs/zh_cn/03-benchmark/benchmark.md +++ b/docs/zh_cn/03-benchmark/benchmark.md @@ -1615,6 +1615,18 @@ GPU: ncnn, TensorRT, PPLNN - - + + GlidingVertex + Rotated Detection + DOTA-v1.0 + mAP + 0.732 + - + 0.733 + 0.731 + - + - + From 32f37c937763461582343d6237fde20c44af3679 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 6 Jul 2022 19:40:29 +0800 Subject: [PATCH 8/9] update test --- tests/regression/mmrotate.yml | 18 ++++++++++++++++++ .../test_mmrotate/test_mmrotate_models.py | 3 ++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/regression/mmrotate.yml b/tests/regression/mmrotate.yml index 45630304e5..dac16838d1 100644 --- a/tests/regression/mmrotate.yml +++ b/tests/regression/mmrotate.yml @@ -48,3 +48,21 @@ models: - *pipeline_ort_detection_dynamic_fp32 - *pipeline_trt_detection_dynamic_fp32 - *pipeline_trt_detection_dynamic_fp16 + + - name: oriented_rcnn + metafile: configs/oriented_rcnn/metafile.yml + model_configs: + - configs/oriented_rcnn/oriented_rcnn_r50_fpn_fp16_1x_dota_le90.py + pipelines: + - *pipeline_ort_detection_dynamic_fp32 + - *pipeline_trt_detection_dynamic_fp32 + - *pipeline_trt_detection_dynamic_fp16 + + - name: gliding_vertex + metafile: configs/gliding_vertex/metafile.yml + model_configs: + - configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py + pipelines: + - *pipeline_ort_detection_dynamic_fp32 + - *pipeline_trt_detection_dynamic_fp32 + - *pipeline_trt_detection_dynamic_fp16 diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py index 52c3610ada..1e2fea4a64 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -499,7 +499,8 @@ def test_gv_ratio_roi_head__simple_test(backend_type: Backend): score_threshold=0.05, iou_threshold=0.1, pre_top_k=2000, - keep_top_k=2000)))) + keep_top_k=2000, + max_output_boxes_per_class=1000)))) angle_version = 'le90' test_cfg = mmcv.Config( dict( From 3ead3716600f1643d8bc51fc94c2d8900c7630b8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 11 Jul 2022 10:11:52 +0800 Subject: [PATCH 9/9] remove ort regression test, remove comment --- .../codebase/mmrotate/models/dense_heads/rotated_rpn_head.py | 2 -- tests/regression/mmrotate.yml | 1 - 2 files changed, 3 deletions(-) diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py index d7c8524be8..586bec8226 100644 --- a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py @@ -3,8 +3,6 @@ from mmdeploy.codebase.mmdet import (get_post_processing_params, pad_with_value_if_necessary) -# from mmdeploy.codebase.mmrotate.core.post_processing import \ -# multiclass_nms_rotated from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import is_dynamic_shape diff --git a/tests/regression/mmrotate.yml b/tests/regression/mmrotate.yml index dac16838d1..af75daf838 100644 --- a/tests/regression/mmrotate.yml +++ b/tests/regression/mmrotate.yml @@ -63,6 +63,5 @@ models: model_configs: - configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py pipelines: - - *pipeline_ort_detection_dynamic_fp32 - *pipeline_trt_detection_dynamic_fp32 - *pipeline_trt_detection_dynamic_fp16