From d9911263730db25135023fb9a209b23c1842d794 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Wed, 3 Aug 2022 15:47:21 +0800 Subject: [PATCH 1/4] fix adaptive_avg_pool exporting to onnx --- mmdeploy/codebase/mmpose/models/__init__.py | 1 - .../mmpose/models/backbones/__init__.py | 5 -- .../mmpose/models/backbones/litehrnet.py | 29 ------ .../mmseg/models/decode_heads/__init__.py | 3 +- .../mmseg/models/decode_heads/psp_head.py | 52 ----------- mmdeploy/pytorch/functions/__init__.py | 3 +- mmdeploy/pytorch/functions/adaptive_pool.py | 38 ++++++++ mmdeploy/pytorch/ops/__init__.py | 15 ++-- mmdeploy/pytorch/ops/adaptive_avg_pool.py | 90 ------------------- mmdeploy/pytorch/ops/adaptive_pool.py | 13 +++ tests/test_pytorch/test_pytorch_functions.py | 18 ++++ tests/test_pytorch/test_pytorch_ops.py | 49 ++++------ 12 files changed, 96 insertions(+), 220 deletions(-) delete mode 100644 mmdeploy/codebase/mmpose/models/backbones/__init__.py delete mode 100644 mmdeploy/codebase/mmpose/models/backbones/litehrnet.py delete mode 100644 mmdeploy/codebase/mmseg/models/decode_heads/psp_head.py create mode 100644 mmdeploy/pytorch/functions/adaptive_pool.py delete mode 100644 mmdeploy/pytorch/ops/adaptive_avg_pool.py create mode 100644 mmdeploy/pytorch/ops/adaptive_pool.py diff --git a/mmdeploy/codebase/mmpose/models/__init__.py b/mmdeploy/codebase/mmpose/models/__init__.py index d1fdb9eb44..0c859a49c7 100644 --- a/mmdeploy/codebase/mmpose/models/__init__.py +++ b/mmdeploy/codebase/mmpose/models/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backbones import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403 from .heads import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmpose/models/backbones/__init__.py b/mmdeploy/codebase/mmpose/models/backbones/__init__.py deleted file mode 100644 index 9309949c52..0000000000 --- a/mmdeploy/codebase/mmpose/models/backbones/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -from .litehrnet import cross_resolution_weighting__forward - -__all__ = ['cross_resolution_weighting__forward'] diff --git a/mmdeploy/codebase/mmpose/models/backbones/litehrnet.py b/mmdeploy/codebase/mmpose/models/backbones/litehrnet.py deleted file mode 100644 index 609eadaef6..0000000000 --- a/mmdeploy/codebase/mmpose/models/backbones/litehrnet.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn.functional as F - -from mmdeploy.core import FUNCTION_REWRITER - - -@FUNCTION_REWRITER.register_rewriter( - 'mmpose.models.backbones.litehrnet.CrossResolutionWeighting.forward') -def cross_resolution_weighting__forward(ctx, self, x): - """Rewrite ``forward`` for default backend. - - Rewrite this function to support export ``adaptive_avg_pool2d``. - - Args: - x (list): block input. 
- """ - - mini_size = [int(_) for _ in x[-1].shape[-2:]] - out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]] - out = torch.cat(out, dim=1) - out = self.conv1(out) - out = self.conv2(out) - out = torch.split(out, self.channels, dim=1) - out = [ - s * F.interpolate(a, size=s.size()[-2:], mode='nearest') - for s, a in zip(x, out) - ] - return out diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/__init__.py b/mmdeploy/codebase/mmseg/models/decode_heads/__init__.py index e893f20460..ca8ec24135 100644 --- a/mmdeploy/codebase/mmseg/models/decode_heads/__init__.py +++ b/mmdeploy/codebase/mmseg/models/decode_heads/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .aspp_head import aspp_head__forward from .ema_head import ema_module__forward -from .psp_head import ppm__forward -__all__ = ['aspp_head__forward', 'ppm__forward', 'ema_module__forward'] +__all__ = ['aspp_head__forward', 'ema_module__forward'] diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/psp_head.py b/mmdeploy/codebase/mmseg/models/decode_heads/psp_head.py deleted file mode 100644 index 210e6c7ad5..0000000000 --- a/mmdeploy/codebase/mmseg/models/decode_heads/psp_head.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import torch.nn as nn -from mmseg.ops import resize - -from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import IR, get_root_logger, is_dynamic_shape - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmseg.models.decode_heads.psp_head.PPM.forward', ir=IR.ONNX) -def ppm__forward(ctx, self, x): - """Rewrite `forward` for default backend. - - Support configured dynamic/static shape in resize op. - - Args: - ctx (ContextCaller): The context with additional information. - self: The instance of the original class. - x (Tensor): The transformed input feature. - - Returns: - List[torch.Tensor]: Up-sampled segmentation maps of different - scales. - """ - deploy_cfg = ctx.cfg - is_dynamic_flag = is_dynamic_shape(deploy_cfg) - # get origin input shape as tensor to support onnx dynamic shape - size = x.shape[2:] - if not is_dynamic_flag: - size = [int(val) for val in size] - - ppm_outs = [] - for ppm in self: - if isinstance(ppm[0], nn.AdaptiveAvgPool2d) and \ - ppm[0].output_size != 1: - if is_dynamic_flag: - logger = get_root_logger() - logger.warning('`AdaptiveAvgPool2d` would be ' - 'replaced to `AvgPool2d` explicitly') - # replace AdaptiveAvgPool2d with AvgPool2d explicitly - output_size = 2 * [ppm[0].output_size] - k = [int(size[i] / output_size[i]) for i in range(0, len(size))] - ppm[0] = nn.AvgPool2d(k, stride=k, padding=0, ceil_mode=False) - ppm_out = ppm(x) - upsampled_ppm_out = resize( - ppm_out, - size=size, - mode='bilinear', - align_corners=self.align_corners) - ppm_outs.append(upsampled_ppm_out) - return ppm_outs diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 4201942476..b79304e595 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .adaptive_pool import adaptive_avg_pool2d__default from .atan2 import atan2__default from .chunk import chunk__ncnn, chunk__torchscript from .expand import expand__ncnn @@ -20,5 +21,5 @@ 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', 'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', 'chunk__torchscript', 'masked_fill__onnxruntime', - 'tensor__setitem__default' + 'tensor__setitem__default', 'adaptive_avg_pool2d__default' ] diff --git a/mmdeploy/pytorch/functions/adaptive_pool.py b/mmdeploy/pytorch/functions/adaptive_pool.py new file mode 100644 index 0000000000..977348cfc3 --- /dev/null +++ b/mmdeploy/pytorch/functions/adaptive_pool.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import (Backend, get_backend, get_root_logger, + is_dynamic_shape) + + +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.nn.functional.adaptive_avg_pool2d') +def adaptive_avg_pool2d__default(ctx, input, output_size): + """Rewrite `adaptive_avg_pool2d` for default backend.""" + supported_backends = [Backend.TORCHSCRIPT, Backend.NCNN] + if get_backend(ctx.cfg) in supported_backends: + return ctx.origin_func(input, output_size) + + output_size = _pair(output_size) + if int(output_size[0]) == int(output_size[1]) == 1: + out = ctx.origin_func(input, output_size) + else: + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + if is_dynamic_flag: + logger = get_root_logger() + logger.warning('`adaptive_avg_pool2d` would be ' + 'replaced to `avg_pool2d` explicitly') + size = input.shape[2:] + k = [int(size[i] / output_size[i]) for i in range(0, len(size))] + out = F.avg_pool2d( + input, + kernel_size=k, + stride=k, + padding=0, + ceil_mode=False, + count_include_pad=False) + return out diff --git a/mmdeploy/pytorch/ops/__init__.py b/mmdeploy/pytorch/ops/__init__.py index 173ef27516..56f89a621e 100644 --- a/mmdeploy/pytorch/ops/__init__.py +++ b/mmdeploy/pytorch/ops/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .adaptive_avg_pool import (adaptive_avg_pool1d__default, - adaptive_avg_pool2d__default, - adaptive_avg_pool2d__ncnn, - adaptive_avg_pool3d__default) +from .adaptive_pool import adaptive_avg_pool2d__ncnn from .gelu import gelu__ncnn from .grid_sampler import grid_sampler__default from .hardsigmoid import hardsigmoid__default @@ -15,10 +12,8 @@ from .squeeze import squeeze__default __all__ = [ - 'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default', - 'adaptive_avg_pool3d__default', 'grid_sampler__default', - 'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn', - 'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn', - 'layer_norm__ncnn', 'linear__ncnn', '_prepare_onnx_paddings__tensorrt', - 'roll_default' + 'grid_sampler__default', 'hardsigmoid__default', 'instance_norm__tensorrt', + 'generic_rnn__ncnn', 'squeeze__default', 'adaptive_avg_pool2d__ncnn', + 'gelu__ncnn', 'layer_norm__ncnn', 'linear__ncnn', + '_prepare_onnx_paddings__tensorrt', 'roll_default' ] diff --git a/mmdeploy/pytorch/ops/adaptive_avg_pool.py b/mmdeploy/pytorch/ops/adaptive_avg_pool.py deleted file mode 100644 index fc9c86c823..0000000000 --- a/mmdeploy/pytorch/ops/adaptive_avg_pool.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved -# Modified from: -# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py - -from torch.nn.modules.utils import _pair, _single, _triple -from torch.onnx.symbolic_helper import parse_args - -from mmdeploy.core import SYMBOLIC_REWRITER - - -def _adaptive_pool(name, type, tuple_fn, fn=None): - """Generic adaptive pooling.""" - - @parse_args('v', 'is') - def symbolic_fn(g, input, output_size): - if output_size == [1] * len(output_size) and type == 'AveragePool': - return g.op('GlobalAveragePool', input) - if not input.isCompleteTensor(): - if output_size == [1] * len(output_size): - return g.op('GlobalMaxPool', input), None - raise NotImplementedError( - '[Adaptive pool]:input size not accessible') - dim = input.type().sizes()[2:] - if output_size == [1] * len(output_size) and type == 'MaxPool': - return g.op('GlobalMaxPool', input), None - - # compute stride = floor(input_size / output_size) - s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] - - # compute kernel_size = input_size - (output_size - 1) * stride - k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))] - - # call max_poolxd_with_indices to get indices in the output - if type == 'MaxPool': - return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim), - False) - output = g.op( - type, - input, - kernel_shape_i=tuple_fn(k), - strides_i=tuple_fn(s), - ceil_mode_i=False) - return output - - return symbolic_fn - - -adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool', - _single) -adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool', - _pair) -adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool', - _triple) - - -@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True) -def adaptive_avg_pool1d__default(ctx, *args): - """Register default symbolic function for `adaptive_avg_pool1d`. - - Align symbolic of adaptive_pool between different torch version. - """ - return adaptive_avg_pool1d(*args) - - -@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True) -def adaptive_avg_pool2d__default(ctx, *args): - """Register default symbolic function for `adaptive_avg_pool2d`. - - Align symbolic of adaptive_pool between different torch version. - """ - return adaptive_avg_pool2d(*args) - - -@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True) -def adaptive_avg_pool3d__default(ctx, *args): - """Register default symbolic function for `adaptive_avg_pool3d`. - - Align symbolic of adaptive_pool between different torch version. - """ - return adaptive_avg_pool3d(*args) - - -@SYMBOLIC_REWRITER.register_symbolic( - 'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn') -def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size): - """Register ncnn symbolic function for `adaptive_avg_pool2d`. - - Align symbolic of adaptive_avg_pool2d in ncnn. - """ - return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size) diff --git a/mmdeploy/pytorch/ops/adaptive_pool.py b/mmdeploy/pytorch/ops/adaptive_pool.py new file mode 100644 index 0000000000..d27049576b --- /dev/null +++ b/mmdeploy/pytorch/ops/adaptive_pool.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmdeploy.core import SYMBOLIC_REWRITER + + +@SYMBOLIC_REWRITER.register_symbolic( + 'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn') +def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size): + """Register ncnn symbolic function for `adaptive_avg_pool2d`. 
+ + Align symbolic of adaptive_avg_pool2d in ncnn. + """ + return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size) diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index 65bda0ecf4..86bee6814b 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -341,3 +341,21 @@ def setitem_slice(x, y): nodes = onnx_model.graph.node for node in nodes: assert node.op_type != 'ScatterND' + + +@pytest.mark.parametrize('output_size', [1, 3]) +def test_adaptive_avg_pool2d(output_size): + input = torch.rand(1, 3, 6, 6) + model = WrapFunction(F.adaptive_avg_pool2d, output_size=output_size) + pytorch_output = model(input) + deploy_cfg_ort = mmcv.Config( + dict( + onnx_config=dict(input_shape=None), + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmdet', task='ObjectDetection'))) + rewrite_output, _ = get_rewrite_outputs( + model, + model_inputs={'input': input}, + deploy_cfg=deploy_cfg_ort, + run_with_backend=True) + assert torch.allclose(pytorch_output, rewrite_output[0]) diff --git a/tests/test_pytorch/test_pytorch_ops.py b/tests/test_pytorch/test_pytorch_ops.py index e3e49f3452..41008fc461 100644 --- a/tests/test_pytorch/test_pytorch_ops.py +++ b/tests/test_pytorch/test_pytorch_ops.py @@ -14,9 +14,20 @@ @pytest.fixture(autouse=False, scope='function') def prepare_symbolics(): context = RewriterContext( - Config({'backend_config': { - 'type': 'tensorrt' - }}), 'tensorrt', opset=11) + Config( + dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None), + backend_config=dict(type='tensorrt'))), + 'tensorrt', + opset=11) context.enter() yield @@ -51,6 +62,8 @@ def forward(self, x): def get_model_onnx_nodes(model, x, onnx_file=onnx_file): torch.onnx.export(model, x, onnx_file, opset_version=11) onnx_model = onnx.load(onnx_file) + import shutil + shutil.copy(onnx_file, './adaptive_avg2d.onnx') nodes = onnx_model.graph.node return nodes @@ -58,18 +71,6 @@ def get_model_onnx_nodes(model, x, onnx_file=onnx_file): @pytest.mark.usefixtures('prepare_symbolics') class TestAdaptivePool: - def test_adaptive_pool_1d_global(self): - x = torch.rand(2, 2, 2) - model = OpModel(torch.nn.functional.adaptive_avg_pool1d, [1]).eval() - nodes = get_model_onnx_nodes(model, x) - assert nodes[0].op_type == 'GlobalAveragePool' - - def test_adaptive_pool_1d(self): - x = torch.rand(2, 2, 2) - model = OpModel(torch.nn.functional.adaptive_avg_pool1d, [2]).eval() - nodes = get_model_onnx_nodes(model, x) - assert nodes[0].op_type == 'AveragePool' - def test_adaptive_pool_2d_global(self): x = torch.rand(2, 2, 2) model = OpModel(torch.nn.functional.adaptive_avg_pool2d, [1, 1]).eval() @@ -80,21 +81,8 @@ def test_adaptive_pool_2d(self): x = torch.rand(2, 2, 2) model = OpModel(torch.nn.functional.adaptive_avg_pool2d, [2, 2]).eval() nodes = get_model_onnx_nodes(model, x) - assert nodes[0].op_type == 'AveragePool' - - def test_adaptive_pool_3d_global(self): - x = torch.rand(2, 2, 2, 2) - model = OpModel(torch.nn.functional.adaptive_avg_pool3d, - [1, 1, 1]).eval() - nodes = get_model_onnx_nodes(model, x) - assert nodes[0].op_type == 'GlobalAveragePool' - - def test_adaptive_pool_3d(self): - x = torch.rand(2, 2, 2, 2) - model = OpModel(torch.nn.functional.adaptive_avg_pool3d, - [2, 2, 2]).eval() - nodes = get_model_onnx_nodes(model, x) - assert 
nodes[0].op_type == 'AveragePool' + print(nodes) + assert nodes[-1].op_type == 'AveragePool' @pytest.mark.usefixtures('prepare_symbolics_ncnn') @@ -123,6 +111,7 @@ def test_instance_norm(): model = OpModel(torch.group_norm, 1, torch.rand([2]), torch.rand([2]), 1e-05).eval() nodes = get_model_onnx_nodes(model, x) + print(nodes) assert nodes[4].op_type == 'TRTInstanceNormalization' assert nodes[4].domain == 'mmdeploy' From 8be815b83275f8c9ae0034821b7053b4bf291dca Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Wed, 3 Aug 2022 15:57:02 +0800 Subject: [PATCH 2/4] remove debug codes --- tests/test_pytorch/test_pytorch_ops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_pytorch/test_pytorch_ops.py b/tests/test_pytorch/test_pytorch_ops.py index 41008fc461..868e82411c 100644 --- a/tests/test_pytorch/test_pytorch_ops.py +++ b/tests/test_pytorch/test_pytorch_ops.py @@ -62,8 +62,6 @@ def forward(self, x): def get_model_onnx_nodes(model, x, onnx_file=onnx_file): torch.onnx.export(model, x, onnx_file, opset_version=11) onnx_model = onnx.load(onnx_file) - import shutil - shutil.copy(onnx_file, './adaptive_avg2d.onnx') nodes = onnx_model.graph.node return nodes @@ -81,7 +79,6 @@ def test_adaptive_pool_2d(self): x = torch.rand(2, 2, 2) model = OpModel(torch.nn.functional.adaptive_avg_pool2d, [2, 2]).eval() nodes = get_model_onnx_nodes(model, x) - print(nodes) assert nodes[-1].op_type == 'AveragePool' @@ -111,7 +108,6 @@ def test_instance_norm(): model = OpModel(torch.group_norm, 1, torch.rand([2]), torch.rand([2]), 1e-05).eval() nodes = get_model_onnx_nodes(model, x) - print(nodes) assert nodes[4].op_type == 'TRTInstanceNormalization' assert nodes[4].domain == 'mmdeploy' From 9bb44d68af406e7c8982097140f1cafae9a5ab2a Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Thu, 4 Aug 2022 09:25:05 +0800 Subject: [PATCH 3/4] fix ci --- tests/test_apis/test_onnx_passes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index a2d77b4463..d3b1a83e23 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -13,6 +13,9 @@ onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name +ort_cfg = dict( + backend_config=dict(type='onnxruntime'), onnx_config=dict(type='onnx')) + def _find_next_node(start: int, nodes: List, op_type: str) -> Tuple[Any, int]: for idx, n in enumerate(nodes[start:]): @@ -166,7 +169,7 @@ def forward(self, x): model = TestModel() x = torch.rand(1, 4, 8, 8) - with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + with RewriterContext(ort_cfg, onnx_custom_passes=_optimize_onnx): torch.onnx.export( model, x, From d3b38294fbdc6a122b8d73936e764db06e1a7150 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Thu, 11 Aug 2022 10:36:37 +0800 Subject: [PATCH 4/4] resolve comment --- mmdeploy/pytorch/functions/__init__.py | 6 ++++-- mmdeploy/pytorch/functions/adaptive_pool.py | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index b79304e595..6be8580375 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .adaptive_pool import adaptive_avg_pool2d__default +from .adaptive_pool import (adaptive_avg_pool2d__default, + adaptive_avg_pool2d__ncnn) from .atan2 import atan2__default from .chunk import chunk__ncnn, chunk__torchscript from .expand import expand__ncnn @@ -21,5 +22,6 @@ 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', 'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', 'chunk__torchscript', 'masked_fill__onnxruntime', - 'tensor__setitem__default', 'adaptive_avg_pool2d__default' + 'tensor__setitem__default', 'adaptive_avg_pool2d__default', + 'adaptive_avg_pool2d__ncnn' ] diff --git a/mmdeploy/pytorch/functions/adaptive_pool.py b/mmdeploy/pytorch/functions/adaptive_pool.py index 977348cfc3..fb09cd82e4 100644 --- a/mmdeploy/pytorch/functions/adaptive_pool.py +++ b/mmdeploy/pytorch/functions/adaptive_pool.py @@ -4,18 +4,13 @@ from torch.nn.modules.utils import _pair from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import (Backend, get_backend, get_root_logger, - is_dynamic_shape) +from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.adaptive_avg_pool2d') def adaptive_avg_pool2d__default(ctx, input, output_size): """Rewrite `adaptive_avg_pool2d` for default backend.""" - supported_backends = [Backend.TORCHSCRIPT, Backend.NCNN] - if get_backend(ctx.cfg) in supported_backends: - return ctx.origin_func(input, output_size) - output_size = _pair(output_size) if int(output_size[0]) == int(output_size[1]) == 1: out = ctx.origin_func(input, output_size) @@ -36,3 +31,14 @@ def adaptive_avg_pool2d__default(ctx, input, output_size): ceil_mode=False, count_include_pad=False) return out + + +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.nn.functional.adaptive_avg_pool2d', + backend=Backend.NCNN.value) +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.nn.functional.adaptive_avg_pool2d', + backend=Backend.TORCHSCRIPT.value) +def adaptive_avg_pool2d__ncnn(ctx, input, output_size): + """Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend.""" + return ctx.origin_func(input, output_size)
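
For context, a minimal standalone sketch (not part of the patches above) of the numerical equivalence the new `adaptive_avg_pool2d__default` rewriter relies on: when the input spatial size divides evenly by the requested output size, `adaptive_avg_pool2d` is exactly an `avg_pool2d` whose kernel and stride are `input_size // output_size`, which is what the rewriter computes before export.

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 6, 6)
    output_size = (3, 3)
    # kernel/stride as the rewriter computes them: input_size // output_size
    k = [x.shape[2] // output_size[0], x.shape[3] // output_size[1]]

    adaptive = F.adaptive_avg_pool2d(x, output_size)
    plain = F.avg_pool2d(x, kernel_size=k, stride=k, padding=0,
                         ceil_mode=False, count_include_pad=False)

    assert torch.allclose(adaptive, plain)

When the input size is not an exact multiple of the output size the two ops diverge, which is presumably why the rewrite logs a warning under dynamic input shapes instead of silently replacing the op.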