Skip to content

Commit

Permalink
pass img_metas while exporting to onnx (#681)
Browse files Browse the repository at this point in the history
* pass img_metas while exporting to onnx

* remove try-catch in tools for beter debugging

* use get

* fix typo
  • Loading branch information
RunningLeon authored Jun 30, 2022
1 parent 5195ff9 commit 17a7d60
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 47 deletions.
2 changes: 2 additions & 0 deletions mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def torch2onnx(img: Any,

torch_model = task_processor.init_pytorch_model(model_checkpoint)
data, model_inputs = task_processor.create_input(img, input_shape)
input_metas = dict(img_metas=data.get('img_metas', None))
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
model_inputs = model_inputs[0]

Expand All @@ -87,6 +88,7 @@ def torch2onnx(img: Any,
export(
torch_model,
model_inputs,
input_metas=input_metas,
output_path_prefix=output_prefix,
backend=backend,
input_names=input_names,
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'mmcls.models.classifiers.ImageClassifier.forward', backend='default')
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def base_classifier__forward(ctx, self, img, *args, **kwargs):
def base_classifier__forward(ctx, self, img, return_loss=False, **kwargs):
"""Rewrite `forward` of BaseClassifier for default backend.
Rewrite this function to call simple_test function,
Expand All @@ -23,5 +23,5 @@ def base_classifier__forward(ctx, self, img, *args, **kwargs):
result(Tensor): The result of classifier.The tensor
shape (batch_size,num_classes).
"""
result = self.simple_test(img, {})
result = self.simple_test(img, **kwargs)
return result
23 changes: 12 additions & 11 deletions mmdeploy/codebase/mmdet/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
def __forward_impl(ctx, self, img, img_metas, **kwargs):
"""Rewrite and adding mark for `forward`.
Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx.
"""
assert isinstance(img_metas, dict)
assert isinstance(img, torch.Tensor)

deploy_cfg = ctx.cfg
Expand All @@ -23,14 +22,18 @@ def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
img_shape = torch._shape_as_tensor(img)[2:]
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
img_metas['img_shape'] = img_shape
img_metas = [img_metas]
img_metas[0]['img_shape'] = img_shape
return self.simple_test(img, img_metas, **kwargs)


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
def base_detector__forward(ctx,
self,
img,
img_metas=None,
return_loss=False,
**kwargs):
"""Rewrite `forward` of BaseDetector for default backend.
Rewrite this function to:
Expand All @@ -56,14 +59,12 @@ def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
corresponds to each class.
"""
if img_metas is None:
img_metas = {}

while isinstance(img_metas, list):
img_metas = [{}]
else:
assert len(img_metas) == 1, 'do not support aug_test'
img_metas = img_metas[0]

if isinstance(img, list):
img = torch.cat(img, 0)
img = img[0]

if 'return_loss' in kwargs:
kwargs.pop('return_loss')
return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)
10 changes: 5 additions & 5 deletions mmdeploy/codebase/mmseg/models/segmentors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
"""
if img_metas is None:
img_metas = {}
while isinstance(img_metas, list):
img_metas = [{}]
else:
assert len(img_metas) == 1, 'do not support aug_test'
img_metas = img_metas[0]

if isinstance(img, list):
img = torch.cat(img, 0)
img = img[0]
assert isinstance(img, torch.Tensor)

deploy_cfg = ctx.cfg
Expand All @@ -37,5 +37,5 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
img_shape = img.shape[2:]
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
img_metas['img_shape'] = img_shape
img_metas[0]['img_shape'] = img_shape
return self.simple_test(img, img_metas, **kwargs)
2 changes: 1 addition & 1 deletion tests/test_codebase/test_mmcls/test_mmcls_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def extract_feat(self, imgs):
def forward_train(self, imgs):
return 'train'

def simple_test(self, img, tmp, **kwargs):
def simple_test(self, img, tmp=None, **kwargs):
return 'simple_test'

model = DummyClassifier().eval()
Expand Down
8 changes: 2 additions & 6 deletions tools/onnx2ncnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ def main():
output_prefix = args.output_prefix

logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
try:
from_onnx(onnx_path, output_prefix)
logger.info('onnx2ncnn success.')
except Exception as e:
logger.error(e)
logger.error('onnx2ncnn failed.')
from_onnx(onnx_path, output_prefix)
logger.info('onnx2ncnn success.')


if __name__ == '__main__':
Expand Down
10 changes: 3 additions & 7 deletions tools/onnx2pplnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,11 @@ def main():
if isinstance(input_shapes[0], int):
input_shapes = [input_shapes]

logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} '
logger.info(f'onnx2pplnn: \n\tonnx_path: {onnx_path} '
f'\n\toutput_prefix: {output_prefix}'
f'\n\topt_shapes: {input_shapes}')
try:
from_onnx(onnx_path, output_prefix, device, input_shapes)
logger.info('onnx2tpplnn success.')
except Exception as e:
logger.error(e)
logger.error('onnx2tpplnn failed.')
from_onnx(onnx_path, output_prefix, device, input_shapes)
logger.info('onnx2pplnn success.')


if __name__ == '__main__':
Expand Down
26 changes: 11 additions & 15 deletions tools/onnx2tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,18 @@ def main():

logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
f'\n\tdeploy_cfg: {deploy_cfg_path}')
try:
from_onnx(
onnx_path,
output_prefix,
input_shapes=final_params['input_shapes'],
log_level=get_trt_log_level(),
fp16_mode=final_params.get('fp16_mode', False),
int8_mode=final_params.get('int8_mode', False),
int8_param=int8_param,
max_workspace_size=final_params.get('max_workspace_size', 0),
device_id=device_id)
from_onnx(
onnx_path,
output_prefix,
input_shapes=final_params['input_shapes'],
log_level=get_trt_log_level(),
fp16_mode=final_params.get('fp16_mode', False),
int8_mode=final_params.get('int8_mode', False),
int8_param=int8_param,
max_workspace_size=final_params.get('max_workspace_size', 0),
device_id=device_id)

logger.info('onnx2tensorrt success.')
except Exception as e:
logger.error(e)
logger.error('onnx2tensorrt failed.')
logger.info('onnx2tensorrt success.')


if __name__ == '__main__':
Expand Down

0 comments on commit 17a7d60

Please sign in to comment.