-
Notifications
You must be signed in to change notification settings - Fork 647
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix adaptive_avg_pool exporting to onnx (#857)
* fix adaptive_avg_pool exporting to onnx * remove debug codes * fix ci * resolve comment
- Loading branch information
1 parent
5fb342e
commit 670a504
Showing
13 changed files
with
104 additions
and
221 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from .backbones import * # noqa: F401,F403 | ||
from .detectors import * # noqa: F401,F403 | ||
from .heads import * # noqa: F401,F403 |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .aspp_head import aspp_head__forward | ||
from .ema_head import ema_module__forward | ||
from .psp_head import ppm__forward | ||
|
||
__all__ = ['aspp_head__forward', 'ppm__forward', 'ema_module__forward'] | ||
__all__ = ['aspp_head__forward', 'ema_module__forward'] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import torch.nn.functional as F | ||
from torch.nn.modules.utils import _pair | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='torch.nn.functional.adaptive_avg_pool2d') | ||
def adaptive_avg_pool2d__default(ctx, input, output_size): | ||
"""Rewrite `adaptive_avg_pool2d` for default backend.""" | ||
output_size = _pair(output_size) | ||
if int(output_size[0]) == int(output_size[1]) == 1: | ||
out = ctx.origin_func(input, output_size) | ||
else: | ||
deploy_cfg = ctx.cfg | ||
is_dynamic_flag = is_dynamic_shape(deploy_cfg) | ||
if is_dynamic_flag: | ||
logger = get_root_logger() | ||
logger.warning('`adaptive_avg_pool2d` would be ' | ||
'replaced to `avg_pool2d` explicitly') | ||
size = input.shape[2:] | ||
k = [int(size[i] / output_size[i]) for i in range(0, len(size))] | ||
out = F.avg_pool2d( | ||
input, | ||
kernel_size=k, | ||
stride=k, | ||
padding=0, | ||
ceil_mode=False, | ||
count_include_pad=False) | ||
return out | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='torch.nn.functional.adaptive_avg_pool2d', | ||
backend=Backend.NCNN.value) | ||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='torch.nn.functional.adaptive_avg_pool2d', | ||
backend=Backend.TORCHSCRIPT.value) | ||
def adaptive_avg_pool2d__ncnn(ctx, input, output_size): | ||
"""Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend.""" | ||
return ctx.origin_func(input, output_size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from mmdeploy.core import SYMBOLIC_REWRITER | ||
|
||
|
||
@SYMBOLIC_REWRITER.register_symbolic( | ||
'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn') | ||
def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size): | ||
"""Register ncnn symbolic function for `adaptive_avg_pool2d`. | ||
Align symbolic of adaptive_avg_pool2d in ncnn. | ||
""" | ||
return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters