Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG P0] Fix CoreML Two Stage detector #1044

Merged
merged 1 commit into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions mmdeploy/backend/coreml/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,66 @@ def coreml_nms(context, node):
max_boxes=max_boxes)

context.add(tuple(results), torch_name=node.outputs[0])


@register_torch_op
def log2(context, node):
"""bind log2."""
import numpy as np
inputs = _get_inputs(context, node)
x = inputs[0]
log_x = mb.log(x=x)
context.add(mb.mul(x=log_x, y=1 / np.log(2.0)), node.name)


@register_torch_op
def roi_align(context, node):
"""roi align."""
inputs = _get_inputs(context, node)

x = context[node.inputs[0]]
input_shape = x.shape # (B, C, h_in, w_in)
if len(input_shape) != 4:
raise ValueError(
'"CropResize" op: expected input rank 4, got {}'.format(x.rank))

const_box_info = True
if context[node.inputs[1]].val is None or context[
node.inputs[2]].val is None:
const_box_info = False

extrapolation_value = context[node.inputs[2]].val
# CoreML index information along with boxes
if const_box_info:
boxes = context[node.inputs[1]].val
# CoreML expects boxes/ROI in
# [N, 1, 5, 1, 1] format
boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
else:
boxes = inputs[1]
boxes = mb.reshape(
x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
# Get Height and Width of crop
h_out = inputs[3]
w_out = inputs[4]

# Torch input format: [B, C, h_in, w_in]
# CoreML input format: [B, C, h_in, w_in]

# Crop Resize
x = mb.crop_resize(
x=x,
roi=boxes,
target_height=h_out.val,
target_width=w_out.val,
normalized_coordinates=False,
spatial_scale=extrapolation_value,
box_coordinate_mode='CORNERS_WIDTH_FIRST',
sampling_mode='OFFSET_CORNERS',
)

# CoreML output format: [N, 1, C, h_out, w_out]
# Torch output format: [N, C, h_out, w_out]
x = mb.squeeze(x=x, axes=[1])

context.add(x, torch_name=node.outputs[0])
18 changes: 12 additions & 6 deletions mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
Expand Down Expand Up @@ -104,11 +104,17 @@ def rpn_head__get_bboxes(ctx,

if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
prior_inds = topk_inds.new_zeros((1, 1))
anchors = anchors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
bbox_pred, scores = gather_topk(
bbox_pred,
scores,
inds=topk_inds,
batch_size=batch_size,
is_batched=True)
anchors = gather_topk(
anchors,
inds=topk_inds,
batch_size=batch_size,
is_batched=False)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,39 @@ def single_roi_extractor__forward__openvino(ctx,
args = (output_size, featmap_strides, sample_num, rois, *feats)
result = SingleRoIExtractorOpenVINO.apply(*args)
return result


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend=Backend.COREML.value)
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def single_roi_extractor__forward__coreml(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` of SingleRoIExtractor for coreml."""
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
if num_levels == 1:
assert len(rois) > 0, 'The number of rois should be positive'
self.roi_layers[0].use_torchvision = True
return self.roi_layers[0](feats[0], rois)

target_lvls = self.map_roi_levels(rois, num_levels)

if roi_scale_factor is not None:
rois = self.roi_rescale(rois, roi_scale_factor)

for i in range(num_levels):
mask = target_lvls == i
# inds = mask.nonzero(as_tuple=False).squeeze(1)
rois_t = rois * mask.unsqueeze(-1)
# use the roi align in torhcvision
self.roi_layers[i].use_torchvision = True
roi_feats_t = self.roi_layers[i](feats[i], rois_t)
roi_feats = roi_feats + roi_feats_t * (rois_t[:, -1] > 0).reshape(
-1, 1, 1, 1)
# slice to recover original size
return roi_feats