Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support MMRotate model with le135 #788

Merged
merged 4 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>

#include "../../ir/subgraph_matcher.h"
#include "common_subgraph_elimination.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace mmdeploy {
Expand Down Expand Up @@ -126,14 +127,16 @@ void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& par

void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params) {
// cse before search
CommonSubgraphElimination(graph, params);

std::string pattern_str = R"IR(
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes):
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2):
%nz_1 = onnx::NonZero(%cmp_1)
%trans_1 = onnx::Transpose(%nz_1)
%gather_1 = onnx::GatherND(%z, %trans_1)
%reshape_1_shape = onnx::Constant()
%reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
%shape_2 = onnx::Shape(%y)
%expand_2 = onnx::Expand(%cmp_2, %shape_2)
%nz_2 = onnx::NonZero(%expand_2)
%trans_2 = onnx::Transpose(%nz_2)
Expand Down
62 changes: 62 additions & 0 deletions mmdeploy/codebase/mmrotate/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,65 @@ def poly2obb_le90__tensorrt(ctx, polys: torch.Tensor) -> torch.Tensor:
width, _ = torch.max(edges, 1)
height, _ = torch.min(edges, 1)
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmrotate.core.bbox.transforms.poly2obb_le135')
def poly2obb_le135__default(ctx, polys):
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
"""This is a rewrite for poly2obb to remove NonZero ops.

Args:
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]

Returns:
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
"""
polys = torch.reshape(polys, [-1, 8])
pt1, pt2, pt3, pt4 = polys[..., :8].chunk(4, 1)
edge1 = torch.sqrt(
torch.pow(pt1[..., 0] - pt2[..., 0], 2) +
torch.pow(pt1[..., 1] - pt2[..., 1], 2))
edge2 = torch.sqrt(
torch.pow(pt2[..., 0] - pt3[..., 0], 2) +
torch.pow(pt2[..., 1] - pt3[..., 1], 2))
angles1 = torch.atan2((pt2[..., 1] - pt1[..., 1]),
(pt2[..., 0] - pt1[..., 0]))
angles2 = torch.atan2((pt4[..., 1] - pt1[..., 1]),
(pt4[..., 0] - pt1[..., 0]))
angles = torch.where(edge1 > edge2, angles1, angles2)
angles = norm_angle(angles, 'le135')
x_ctr = (pt1[..., 0] + pt3[..., 0]) / 2.0
y_ctr = (pt1[..., 1] + pt3[..., 1]) / 2.0
edges = torch.stack([edge1, edge2], dim=1)
width, _ = torch.max(edges, 1)
height, _ = torch.min(edges, 1)
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmrotate.core.bbox.transforms.obb2poly_le135')
def obb2poly_le135__default(ctx, rboxes):
"""Support batched input.

Args:
ctx : context of rewriter
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]

Returns:
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
"""
B, N = rboxes.shape[:2]
x_ctr, y_ctr, width, height, angle = rboxes[..., 0], rboxes[
..., 1], rboxes[..., 2], rboxes[..., 3], rboxes[..., 4]
tl_x, tl_y, br_x, br_y = \
-width * 0.5, -height * 0.5, \
width * 0.5, height * 0.5
rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y],
dim=-1).reshape(B, N, 2, 4)
sin, cos = torch.sin(angle), torch.cos(angle)
M = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(B, N, 2, 2)
polys = M.matmul(rects).permute(0, 1, 3, 2)
xy_ctr = torch.stack([x_ctr, y_ctr], dim=-1).unsqueeze(-2)
polys += xy_ctr
polys = polys.reshape(B, N, 8)
return polys.contiguous()
70 changes: 70 additions & 0 deletions tests/test_codebase/test_mmrotate/test_mmrotate_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,76 @@ def poly2obb_le90(*args, **kwargs):
assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_poly2obb_le135(backend_type: Backend):
check_backend(backend_type)
polys = torch.rand(1, 10, 8)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type=backend_type.value,
model_inputs=[
dict(
input_shapes=dict(
polys=dict(
min_shape=polys.shape,
opt_shape=polys.shape,
max_shape=polys.shape)))
]),
codebase_config=dict(type='mmrotate', task='RotatedDetection')))

# wrap function to enable rewrite
def poly2obb_le135(*args, **kwargs):
import mmrotate
return mmrotate.core.bbox.transforms.poly2obb_le135(*args, **kwargs)

# wrap function to nn.Module, enable torch.onnx.export
wrapped_func = WrapFunction(poly2obb_le135)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={'polys': polys},
deploy_cfg=deploy_cfg,
run_with_backend=False)

assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_obb2poly_le135(backend_type: Backend):
check_backend(backend_type)
rboxes = torch.rand(1, 10, 5)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type=backend_type.value,
model_inputs=[
dict(
input_shapes=dict(
rboxes=dict(
min_shape=rboxes.shape,
opt_shape=rboxes.shape,
max_shape=rboxes.shape)))
]),
codebase_config=dict(type='mmrotate', task='RotatedDetection')))

# wrap function to enable rewrite
def obb2poly_le135(*args, **kwargs):
import mmrotate
return mmrotate.core.bbox.transforms.obb2poly_le135(*args, **kwargs)

# wrap function to nn.Module, enable torch.onnx.export
wrapped_func = WrapFunction(obb2poly_le135)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={'rboxes': rboxes},
deploy_cfg=deploy_cfg,
run_with_backend=False)

assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_gvfixcoder__decode(backend_type: Backend):
check_backend(backend_type)
Expand Down