Skip to content

Commit

Permalink
support det wip
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire authored and irexyc committed Jul 27, 2022
1 parent 685d448 commit 5c42562
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 32 deletions.
12 changes: 3 additions & 9 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,26 +300,20 @@ def multiclass_nms__coreml(ctx,
batch_size = scores.shape[0]
assert batch_size == 1, 'batched nms is not supported for now.'

# box_per_cls = len(boxes.shape) == 4

# pre-topk
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
boxes = boxes[:, topk_inds.squeeze(), ...]
scores = scores[:, topk_inds.squeeze(), ...]

boxes, scores, _, _ = coreml_nms(boxes, scores, iou_threshold,
score_threshold,
max_output_boxes_per_class)
boxes, scores, _, _ = coreml_nms(
boxes, scores, iou_threshold, score_threshold,
min(keep_top_k, max_output_boxes_per_class))

scores, labels = scores.max(-1)
dets = torch.cat([boxes, scores.unsqueeze(-1)], dim=-1)

if keep_top_k > 0:
dets = dets[:, :keep_top_k, ...]
labels = labels[:, :keep_top_k]

return dets, labels


Expand Down
156 changes: 156 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,162 @@ def base_dense_head__get_bbox(ctx,
keep_top_k=keep_top_k)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.'
'BaseDenseHead.get_bboxes',
backend='coreml')
def base_dense_head__get_bbox__coreml(ctx,
self,
cls_scores,
bbox_preds,
score_factors=None,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True,
**kwargs):
"""Rewrite `get_bboxes` of `BaseDenseHead` for CoreML backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Default None.
img_metas (list[dict], Optional): Image meta info. Default None.
cfg (mmcv.Config, Optional): Test / postprocessing configuration,
if None, test_cfg would be used. Default None.
rescale (bool): If True, return boxes in original image space.
Default False.
with_nms (bool): If True, do nms before return boxes.
Default True.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)
mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors]

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
if score_factors is None:
with_score_factors = False
mlvl_score_factor = [None for _ in range(num_levels)]
else:
with_score_factors = True
mlvl_score_factor = [
score_factors[i].detach() for i in range(num_levels)
]
mlvl_score_factors = []
assert img_metas is not None
img_shape = img_metas[0]['img_shape']
assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
batch_size = cls_scores[0].shape[0]
assert batch_size == 1, \
'coreml only support detection model with batch_size==1'
cfg = self.test_cfg
pre_topk = cfg.get('nms_pre', -1)

mlvl_valid_bboxes = []
mlvl_valid_scores = []
mlvl_valid_priors = []

for cls_score, bbox_pred, score_factors, priors in zip(
mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = scores.sigmoid()
else:
scores = scores.softmax(-1)
if with_score_factors:
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
-1).sigmoid()
score_factors = score_factors.unsqueeze(2)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
if pre_topk > 0:
nms_pre_score = scores
if with_score_factors:
nms_pre_score = nms_pre_score * score_factors

# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = nms_pre_score.max(-1)
else:
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
topk_inds = topk_inds.squeeze(0)
priors = priors[:, topk_inds, :]
bbox_pred = bbox_pred[:, topk_inds, :]
scores = scores[:, topk_inds, :]
if with_score_factors:
score_factors = score_factors[:, topk_inds, :]

mlvl_valid_bboxes.append(bbox_pred)
mlvl_valid_scores.append(scores)
mlvl_valid_priors.append(priors)
if with_score_factors:
mlvl_score_factors.append(score_factors)

batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
batch_scores = torch.cat(mlvl_valid_scores, dim=1)
batch_priors = torch.cat(mlvl_valid_priors, dim=1)
batch_bboxes = self.bbox_coder.decode(
batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)
if with_score_factors:
batch_score_factors = torch.cat(mlvl_score_factors, dim=1)

if not self.use_sigmoid_cls:
batch_scores = batch_scores[..., :self.num_classes]

if with_score_factors:
batch_scores = batch_scores * batch_score_factors

if not with_nms:
return batch_bboxes, batch_scores

post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms(
batch_bboxes,
batch_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
'.get_bboxes',
Expand Down
23 changes: 0 additions & 23 deletions mmdeploy/pytorch/functions/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,3 @@ def topk__tensorrt(ctx,
k = MAX_TOPK_K

return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)


@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='coreml')
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='coreml')
def topk__coreml(ctx,
input: torch.Tensor,
k: int,
dim: Optional[int] = None,
largest: bool = True,
sorted: bool = True):
"""Rewrite `topk` for CoreML backend.
Replace topk with sort + slice.
"""
assert sorted
dim = -1 if dim is None else dim
dim = dim if dim > 0 else input.dim() + dim
value, index = input.sort(dim=dim, descending=largest)
k = min(k, input.size(dim))
slices = [slice(None)] * (dim + 1)
slices[dim] = slice(k)
return value[slices], index[slices]

0 comments on commit 5c42562

Please sign in to comment.