From 299014c24f86ad48d355f69693e49a6ee0b3a6e2 Mon Sep 17 00:00:00 2001
From: hanruobing
Date: Mon, 13 Jul 2020 18:35:42 +0800
Subject: [PATCH 1/5] add pytorch2onnx part

---
 mmseg/models/segmentors/encoder_decoder.py |  21 +-
 setup.cfg                                  |   2 +-
 tools/onnx_util/__init__.py                |   0
 tools/onnx_util/symbolic.py                | 131 +++++++++
 tools/onnx_util/symbolic_helper.py         | 300 +++++++++++++++++++++
 tools/pytorch2onnx.py                      | 189 +++++++++++++
 6 files changed, 634 insertions(+), 9 deletions(-)
 create mode 100644 tools/onnx_util/__init__.py
 create mode 100644 tools/onnx_util/symbolic.py
 create mode 100644 tools/onnx_util/symbolic_helper.py
 create mode 100644 tools/pytorch2onnx.py

diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index d3ce17adbb..d1709e0ca3 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -171,6 +172,8 @@ def slide_inference(self, img, img_meta, rescale):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
+        assert h_crop <= h_img and w_crop <= w_img, (
+            'crop size should not be greater than image size')
         num_classes = self.num_classes
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
@@ -185,14 +188,15 @@ def slide_inference(self, img, img_meta, rescale):
                 y1 = max(y2 - h_crop, 0)
                 x1 = max(x2 - w_crop, 0)
                 crop_img = img[:, :, y1:y2, x1:x2]
-                pad_img = crop_img.new_zeros(
-                    (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
-                pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
-                pad_seg_logit = self.encode_decode(pad_img, img_meta)
-                preds[:, :, y1:y2,
-                      x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
+        # We want to regard count_mat as a constant while exporting to ONNX
+        count_mat = torch.from_numpy(count_mat.detach().numpy())
         preds = preds / count_mat
         if rescale:
             preds = resize(
@@ -201,7 +205,6 @@ def slide_inference(self, img, img_meta, rescale):
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
-
         return preds
 
     def whole_inference(self, img, img_meta, rescale):
@@ -243,8 +246,8 @@ def inference(self, img, img_meta, rescale):
             seg_logit = self.whole_inference(img, img_meta, rescale)
         output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
-        flip_direction = img_meta[0]['flip_direction']
         if flip:
+            flip_direction = img_meta[0]['flip_direction']
             assert flip_direction in ['horizontal', 'vertical']
             if flip_direction == 'horizontal':
                 output = output.flip(dims=(3, ))
@@ -257,6 +260,8 @@ def simple_test(self, img, img_meta, rescale=True):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
diff --git a/setup.cfg b/setup.cfg
index 2102a8ca60..487c230411 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx_util,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tools/onnx_util/__init__.py b/tools/onnx_util/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/onnx_util/symbolic.py b/tools/onnx_util/symbolic.py
new file mode 100644
index 0000000000..c44eb6f802
--- /dev/null
+++ b/tools/onnx_util/symbolic.py
@@ -0,0 +1,131 @@
+"""Modified from https://github.com/pytorch/pytorch."""
+import onnx_util.symbolic_helper as sym_help
+import torch
+from torch.onnx.symbolic_helper import parse_args
+from torch.onnx.symbolic_registry import register_op
+
+
+def _interpolate(name, dim, interpolate_mode):
+
+    def symbolic_fn(g, input, output_size, *args):
+        scales, align_corners = sym_help._get_interpolate_attributes(
+            g, interpolate_mode, args)
+        align_corners = sym_help._maybe_get_scalar(align_corners)
+        transformation_mode = 'asymmetric' \
+            if interpolate_mode == 'nearest' \
+            else 'align_corners' if align_corners else 'pytorch_half_pixel'
+        empty_tensor = g.op(
+            'Constant', value_t=torch.tensor([], dtype=torch.float32))
+
+        if scales is None:
+            input_size = g.op('Shape', input)
+            input_size_beg = sym_help._slice_helper(
+                g, input_size, axes=[0], ends=[2], starts=[0])
+            output_size = g.op(
+                'Cast',
+                output_size,
+                to_i=sym_help.cast_pytorch_to_onnx['Long'])
+            output_size = g.op('Concat', input_size_beg, output_size, axis_i=0)
+            scales = g.op(
+                'Constant', value_t=torch.tensor([], dtype=torch.float32))
+            return g.op(
+                'Resize',
+                input,
+                empty_tensor,
+                # roi only takes effect with
+                # coordinate_transformation_mode="tf_crop_and_resize"
+                scales,  # scales is not needed since we are sending out_size
+                output_size,
+                coordinate_transformation_mode_s=transformation_mode,
+                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
+                mode_s=interpolate_mode,  # nearest, linear, or cubic
+                nearest_mode_s='floor')  # only valid when mode="nearest"
+        else:
+            return g.op(
+                'Resize',
+                input,
+                empty_tensor,
+                # roi only takes effect with
+                # coordinate_transformation_mode="tf_crop_and_resize"
+                scales,  # scales is not needed since we are sending out_size
+                coordinate_transformation_mode_s=transformation_mode,
+                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
+                mode_s=interpolate_mode,  # nearest, linear, or cubic
+                nearest_mode_s='floor')  # only valid when mode="nearest"
+
+    return symbolic_fn
+
+
+upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
+upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
+upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
+upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
+upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
+upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
+upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
+
+
+@parse_args('v', 'v', 'i', 'i', 'i', 'none')
+def topk(g, self, k, dim, largest, sorted, out=None):
+    return sym_help._topk_helper(
+        g, self, k, dim, largest=largest, sorted=sorted, out=out)
+
+
+def masked_select(g, self, mask):
+    from torch.onnx.symbolic_opset9 import nonzero, expand_as
+    index = nonzero(g, expand_as(g, mask, self))
+    return g.op('GatherND', self, index)
+
+
+# Modified from PyTorch 1.5.0
+def _prepare_onnx_paddings(g, dim, pad):
+    pad_len = torch.onnx.symbolic_opset9.size(
+        g, pad, g.op('Constant', value_t=torch.tensor([0])))
+    # Set extension = [0] * (dim * 2 - len(pad))
+    extension = g.op(
+        'Sub',
+        g.op('Mul',
+             g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
+             g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
+        pad_len)
+    pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
+    paddings = g.op(
+        'Concat',
+        pad,
+        g.op(
+            'ConstantOfShape',
+            extension,
+            value_t=torch.tensor([0], dtype=torch.int64)),
+        axis_i=0)
+    paddings = g.op('Reshape', paddings,
+                    g.op('Constant', value_t=torch.tensor([-1, 2])))
+    paddings = g.op(
+        'Transpose',
+        torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
+        perm_i=[1, 0])
+    paddings = g.op('Reshape', paddings,
+                    g.op('Constant', value_t=torch.tensor([-1])))
+    padding_c = g.op(
+        'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
+    return padding_c
+
+
+def constant_pad_nd(g, input, padding, value=None):
+    mode = 'constant'
+    value = sym_help._maybe_get_scalar(value)
+    value = sym_help._if_scalar_type_as(g, value, input)
+    pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
+    return g.op('Pad', input, pad, value, mode_s=mode)
+
+
+def register_extra_symbolics(opset=11):
+    register_op('topk', topk, '', opset)
+    register_op('constant_pad_nd', constant_pad_nd, '', opset)
+    register_op('masked_select', masked_select, '', opset)
+    register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
+    register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
+    register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
+    register_op('upsample_linear1d', upsample_linear1d, '', opset)
+    register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
+    register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
+    register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
diff --git a/tools/onnx_util/symbolic_helper.py b/tools/onnx_util/symbolic_helper.py
new file mode 100644
index 0000000000..1980b57d4f
--- /dev/null
+++ b/tools/onnx_util/symbolic_helper.py
@@ -0,0 +1,300 @@
+"""Modified from https://github.com/pytorch/pytorch."""
+from __future__ import absolute_import, division, print_function
+import warnings
+from functools import wraps
+from sys import maxsize as maxsize
+
+import torch
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+from torch._C import ListType
+
+# ---------------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------------
+
+# Save some builtins as locals, because we'll shadow them below
+_sum = sum
+
+
+def _parse_arg(value, desc):
+    if desc == 'none':
+        return value
+    if desc == 'v' or not _is_value(value):
+        return value
+    if value.node().mustBeNone():
+        return None
+    if value.node().kind() == 'onnx::Constant':
+        tval = value.node()['value']
+        if desc == 'i':
+            return int(tval)
+        elif desc == 'f':
+            return float(tval)
+        elif desc == 'b':
+            return bool(tval)
+        elif desc == 's':
+            return str(tval)
+        elif desc == 't':
+            return tval
+        elif desc == 'is':
+            return [int(v) for v in tval]
+        else:
+            raise RuntimeError(
+                "ONNX symbolic doesn't know to interpret Constant node")
+    elif value.node().kind() == 'prim::ListConstruct':
+        if desc == 'is':
+            for v in value.node().inputs():
+                if v.node().kind() != 'onnx::Constant':
+                    raise RuntimeError(
+                        "Failed to export an ONNX attribute '" +
+                        v.node().kind() +
+                        "', since it's not constant, please try to make "
+                        'things (e.g., kernel size) static if possible')
+            return [int(v.node()['value']) for v in value.node().inputs()]
+        else:
+            raise RuntimeError(
+                "ONNX symbolic doesn't know to interpret ListConstruct node")
+
+    raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
+
+
+def _maybe_get_const(value, desc):
+    if _is_value(value) and value.node().kind() == 'onnx::Constant':
+        return _parse_arg(value, desc)
+    return value
+
+
+def _maybe_get_scalar(value):
+    value_t = _maybe_get_const(value, 't')
+    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
+        return value_t
+    return value
+
+
+def _get_const(value, desc, arg_name):
+    if _is_value(value) and value.node().kind() not in ('onnx::Constant',
+                                                        'prim::Constant'):
+        raise RuntimeError('ONNX symbolic expected a constant'
+                           ' value of the {} argument, got `{}`'.format(
+                               arg_name, value))
+    return _parse_arg(value, desc)
+
+
+def _unpack_list(list_value):
+    list_node = list_value.node()
+    assert list_node.kind() == 'prim::ListConstruct'
+    return list(list_node.inputs())
+
+
+# Check if list_value is output from prim::ListConstruct
+# This is usually called before _unpack_list to ensure the list can be
+# unpacked.
+def _is_packed_list(list_value):
+    return _is_value(
+        list_value) and list_value.node().kind() == 'prim::ListConstruct'
+
+
+def parse_args(*arg_descriptors):
+
+    def decorator(fn):
+        fn._arg_descriptors = arg_descriptors
+
+        def wrapper(g, *args):
+            # some args may be optional, so the length may be smaller
+            assert len(arg_descriptors) >= len(args)
+            args = [
+                _parse_arg(arg, arg_desc)
+                for arg, arg_desc in zip(args, arg_descriptors)
+            ]
+            return fn(g, *args)
+
+        # In Python 2 functools.wraps chokes on partially applied functions, so
+        # we need this as a workaround
+        try:
+            wrapper = wraps(fn)(wrapper)
+        except Exception:
+            pass
+        return wrapper
+
+    return decorator
+
+
+def _scalar(x):
+    """Convert a scalar tensor into a Python value."""
+    assert x.numel() == 1
+    return x.item()
+
+
+def _if_scalar_type_as(g, self, tensor):
+    """Convert self into the same type of tensor, as necessary.
+
+    We only support implicit casting for scalars, so we never actually need to
+    insert an ONNX cast operator here; just fix up the scalar.
+    """
+    if isinstance(self, torch._C.Value):
+        return self
+
+    scalar_type = tensor.type().scalarType()
+    if scalar_type:
+        ty = scalar_type.lower()
+        return getattr(self, ty)()
+
+    return self
+
+
+def _is_none(x):
+    return x.node().mustBeNone()
+
+
+def _is_value(x):
+    return isinstance(x, torch._C.Value)
+
+
+def _is_tensor_list(x):
+    return x.type().isSubtypeOf(ListType.ofTensors())
+
+
+def _unimplemented(op, msg):
+    warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
+                  ' not supported')
+
+
+def _try_get_scalar_type(*args):
+    for arg in args:
+        try:
+            return arg.type().scalarType()
+        except RuntimeError:
+            pass
+    return None
+
+
+def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
+    if out is not None:
+        _unimplemented('TopK', 'Out parameter is not supported')
+    if not _is_value(k):
+        k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
+    else:
+        k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
+    return g.op(
+        'TopK',
+        input,
+        k,
+        axis_i=dim,
+        largest_i=largest,
+        sorted_i=sorted,
+        outputs=2)
+
+
+def _slice_helper(g,
+                  input,
+                  axes,
+                  starts,
+                  ends,
+                  steps=None,
+                  dynamic_slice=False):
+    # TODO(ruobing): add support for opset<10
+    from torch.onnx.symbolic_opset10 import _slice
+    return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
+
+
+def _unsqueeze_helper(g, input, dim):
+    from torch.onnx.symbolic_opset9 import unsqueeze
+    return unsqueeze(g, input, dim)
+
+
+def _interpolate_size_to_scales(g, input, output_size, dim):
+    output_size = _maybe_get_const(output_size, 'is')
+    if _is_value(output_size):
+        offset = 2
+        offsets = g.op(
+            'Constant', value_t=torch.ones(offset, dtype=torch.float32))
+        dividend = g.op(
+            'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
+        divisor = _slice_helper(
+            g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
+        divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
+        scale_dims = g.op('Div', dividend, divisor)
+        scales = g.op('Concat', offsets, scale_dims, axis_i=0)
+    else:
+        scales_constant = [
+            1. if i < 2 else float(output_size[-(dim - i)]) /
+            float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
+        ]
+        scales = g.op(
+            'Constant',
+            value_t=torch.tensor(scales_constant, dtype=torch.float32))
+    return scales
+
+
+def _interpolate_get_scales_if_available(g, scales):
+    if len(scales) == 0:
+        return None
+    available_scales = _maybe_get_const(scales[0], 'f') != -1 and not _is_none(
+        scales[0])
+
+    if not available_scales:
+        return None
+
+    scales_list = []
+    for scale in scales:
+        unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
+        # ONNX only supports float for the scales. double -> float.
+        unsqueezed_scale = g.op(
+            'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
+        scales_list.append(unsqueezed_scale)
+    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
+    scales = g.op('Concat', offsets, *scales_list, axis_i=0)
+    return scales
+
+
+def _get_interpolate_attributes(g, mode, args):
+    if mode == 'nearest':
+        align_corners = None
+        scales = args[0:]
+    else:
+        align_corners = args[0]
+        scales = args[1:]
+    scales = _interpolate_get_scales_if_available(g, scales)
+    return scales, align_corners
+
+
+def _interpolate_get_scales(g, scale_factor, dim):
+    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
+    if isinstance(scale_factor.type(), torch._C.ListType):
+        return g.op('Concat', offsets, scale_factor, axis_i=0)
+    else:
+        scale_factor = _unsqueeze_helper(g, scale_factor, 0)
+        scale_factor = g.op(
+            'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
+        scales = [scale_factor for i in range(dim - 2)]
+    scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
+    return scale_factor
+
+
+# Metaprogram symbolics for each ATen native specialized cast operator.
+# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
+# ONNX cast node with `to` attribute 'UINT8'
+#
+# TODO: remove these once we support Type's in the JIT IR and we can once again
+# use the unified toType operator
+cast_pytorch_to_onnx = {
+    'Byte': torch.onnx.TensorProtoDataType.UINT8,
+    'Char': torch.onnx.TensorProtoDataType.INT8,
+    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
+    'Float': torch.onnx.TensorProtoDataType.FLOAT,
+    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
+    'Int': torch.onnx.TensorProtoDataType.INT32,
+    'Long': torch.onnx.TensorProtoDataType.INT64,
+    'Short': torch.onnx.TensorProtoDataType.INT16,
+    'Bool': torch.onnx.TensorProtoDataType.BOOL,
+    'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
+    'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
+    'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
+}
+
+# Global set to store the list of quantized operators in the network.
+# This is currently only used in the conversion of quantized ops from PT
+# -> C2 via ONNX.
+_quantized_ops = set()
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
new file mode 100644
index 0000000000..c97dcd93d2
--- /dev/null
+++ b/tools/pytorch2onnx.py
@@ -0,0 +1,189 @@
+import argparse
+from functools import partial
+
+import mmcv
+import numpy as np
+import onnxruntime as rt
+import torch
+import torch._C
+import torch.serialization
+from mmcv.runner import load_checkpoint
+from onnx_util.symbolic import register_extra_symbolics
+
+from mmseg.models import build_segmentor
+
+torch.manual_seed(3)
+
+
+def _convert_batchnorm(module):
+    module_output = module
+    if isinstance(module, torch.nn.SyncBatchNorm):
+        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
+                                             module.momentum, module.affine,
+                                             module.track_running_stats)
+        if module.affine:
+            module_output.weight.data = module.weight.data.clone().detach()
+            module_output.bias.data = module.bias.data.clone().detach()
+            # keep requires_grad unchanged
+            module_output.weight.requires_grad = module.weight.requires_grad
+            module_output.bias.requires_grad = module.bias.requires_grad
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+    for name, child in module.named_children():
+        module_output.add_module(name, _convert_batchnorm(child))
+    del module
+    return module_output
+
+
+def _demo_mm_inputs(input_shape, num_classes):  # yapf: disable
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple):
+            input batch dimensions
+
+        num_classes (int):
+            number of semantic classes
+    """
+    (N, C, H, W) = input_shape
+    rng = np.random.RandomState(0)
+    imgs = rng.rand(*input_shape)
+    segs = rng.randint(
+        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
+    img_metas = [{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '<demo>.png',
+        'scale_factor': 1.0,
+        'flip': False,
+    } for _ in range(N)]
+    mm_inputs = {
+        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+        'img_metas': img_metas,
+        'gt_semantic_seg': torch.LongTensor(segs)
+    }
+    return mm_inputs
+
+
+def pytorch2onnx(model,
+                 input_shape,
+                 opset_version=11,
+                 show=False,
+                 output_file='tmp.onnx',
+                 verify_onnx=False):
+    model.cpu().eval()
+
+    num_classes = model.decode_head.num_classes
+
+    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
+
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+
+    img_list = [img[None, :] for img in imgs]
+    img_meta_list = [[img_meta] for img_meta in img_metas]
+
+    # replace original forward function
+    origin_forward = model.forward
+    model.forward = partial(
+        model.forward, img_metas=img_meta_list, return_loss=False)
+
+    register_extra_symbolics(opset_version)
+    with torch.no_grad():
+        torch.onnx.export(
+            model, (img_list, ),
+            output_file,
+            export_params=True,
+            keep_initializers_as_inputs=True,
+            verbose=show,
+            opset_version=opset_version)
+        print(f'Successfully exported ONNX model: {output_file}')
+    model.forward = origin_forward
+
+    if verify_onnx:
+        # check by onnx
+        import onnx
+        onnx_model = onnx.load(output_file)
+        onnx.checker.check_model(onnx_model)
+
+        # check the numerical value
+        # get pytorch output
+        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
+
+        # get onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert (len(net_feed_input) == 1)
+        sess = rt.InferenceSession(output_file)
+        onnx_result = sess.run(
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
+        if not (pytorch_result == onnx_result).all():
+            raise ValueError(
+                'The outputs are different between Pytorch and ONNX')
+        print('The outputs are the same between Pytorch and ONNX')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
+    parser.add_argument('--show', action='store_true', help='show onnx graph')
+    parser.add_argument(
+        '--verify', action='store_true', help='verify the onnx model')
+    parser.add_argument('--output_file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset_version', type=int, default=11)
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=[2048, 1024],
+        help='input image size')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    if len(args.shape) == 1:
+        input_shape = (1, 3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (
+            1,
+            3,
+        ) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    cfg = mmcv.Config.fromfile(args.config)
+    cfg.model.pretrained = None
+
+    # build the model and load checkpoint
+    segmentor = build_segmentor(
+        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    # convert SyncBN to BN
+    segmentor = _convert_batchnorm(segmentor)
+
+    num_classes = segmentor.decode_head.num_classes
+
+    if args.checkpoint:
+        checkpoint = load_checkpoint(
+            segmentor, args.checkpoint, map_location='cpu')
+        # old versions did not save class info in checkpoints,
+        # this workaround is for backward compatibility
+        if 'CLASSES' in checkpoint['meta']:
+            segmentor.CLASSES = checkpoint['meta']['CLASSES']
+
+    # convert model to ONNX file
+    pytorch2onnx(
+        segmentor,
+        input_shape,
+        opset_version=args.opset_version,
+        show=args.show,
+        output_file=args.output_file,
+        verify_onnx=args.verify)

From 6ac9c4a03688650c93acc6d5f7291615fafb0575 Mon Sep 17 00:00:00 2001
From: hanruobing
Date: Mon, 20 Jul 2020 10:17:48 +0800
Subject: [PATCH 2/5] Update according to the latest mmcv

---
 setup.cfg                          |   2 +-
 tools/onnx_util/__init__.py        |   0
 tools/onnx_util/symbolic.py        | 131 -------------
 tools/onnx_util/symbolic_helper.py | 300 -----------------------------
 tools/pytorch2onnx.py              |   4 +-
 5 files changed, 3 insertions(+), 434 deletions(-)
 delete mode 100644 tools/onnx_util/__init__.py
 delete mode 100644 tools/onnx_util/symbolic.py
 delete mode 100644 tools/onnx_util/symbolic_helper.py

diff --git a/setup.cfg b/setup.cfg
index 487c230411..9721e1c5c3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx_util,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tools/onnx_util/__init__.py b/tools/onnx_util/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tools/onnx_util/symbolic.py b/tools/onnx_util/symbolic.py
deleted file mode 100644
index c44eb6f802..0000000000
--- a/tools/onnx_util/symbolic.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""Modified from https://github.com/pytorch/pytorch."""
-import onnx_util.symbolic_helper as sym_help
-import torch
-from torch.onnx.symbolic_helper import parse_args
-from torch.onnx.symbolic_registry import register_op
-
-
-def _interpolate(name, dim, interpolate_mode):
-
-    def symbolic_fn(g, input, output_size, *args):
-        scales, align_corners = sym_help._get_interpolate_attributes(
-            g, interpolate_mode, args)
-        align_corners = sym_help._maybe_get_scalar(align_corners)
-        transformation_mode = 'asymmetric' \
-            if interpolate_mode == 'nearest' \
-            else 'align_corners' if align_corners else 'pytorch_half_pixel'
-        empty_tensor = g.op(
-            'Constant', value_t=torch.tensor([], dtype=torch.float32))
-
-        if scales is None:
-            input_size = g.op('Shape', input)
-            input_size_beg = sym_help._slice_helper(
-                g, input_size, axes=[0], ends=[2], starts=[0])
-            output_size = g.op(
-                'Cast',
-                output_size,
-                to_i=sym_help.cast_pytorch_to_onnx['Long'])
-            output_size = g.op('Concat', input_size_beg, output_size, axis_i=0)
-            scales = g.op(
-                'Constant', value_t=torch.tensor([], dtype=torch.float32))
-            return g.op(
-                'Resize',
-                input,
-                empty_tensor,
-                # roi only takes effect with
-                # coordinate_transformation_mode="tf_crop_and_resize"
-                scales,  # scales is not needed since we are sending out_size
-                output_size,
-                coordinate_transformation_mode_s=transformation_mode,
-                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
-                mode_s=interpolate_mode,  # nearest, linear, or cubic
-                nearest_mode_s='floor')  # only valid when mode="nearest"
-        else:
-            return g.op(
-                'Resize',
-                input,
-                empty_tensor,
-                # roi only takes effect with
-                # coordinate_transformation_mode="tf_crop_and_resize"
-                scales,  # scales is not needed since we are sending out_size
-                coordinate_transformation_mode_s=transformation_mode,
-                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
-                mode_s=interpolate_mode,  # nearest, linear, or cubic
-                nearest_mode_s='floor')  # only valid when mode="nearest"
-
-    return symbolic_fn
-
-
-upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
-upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
-upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
-upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
-upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
-upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
-upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
-
-
-@parse_args('v', 'v', 'i', 'i', 'i', 'none')
-def topk(g, self, k, dim, largest, sorted, out=None):
-    return sym_help._topk_helper(
-        g, self, k, dim, largest=largest, sorted=sorted, out=out)
-
-
-def masked_select(g, self, mask):
-    from torch.onnx.symbolic_opset9 import nonzero, expand_as
-    index = nonzero(g, expand_as(g, mask, self))
-    return g.op('GatherND', self, index)
-
-
-# Modified from PyTorch 1.5.0
-def _prepare_onnx_paddings(g, dim, pad):
-    pad_len = torch.onnx.symbolic_opset9.size(
-        g, pad, g.op('Constant', value_t=torch.tensor([0])))
-    # Set extension = [0] * (dim * 2 - len(pad))
-    extension = g.op(
-        'Sub',
-        g.op('Mul',
-             g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
-             g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
-        pad_len)
-    pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
-    paddings = g.op(
-        'Concat',
-        pad,
-        g.op(
-            'ConstantOfShape',
-            extension,
-            value_t=torch.tensor([0], dtype=torch.int64)),
-        axis_i=0)
-    paddings = g.op('Reshape', paddings,
-                    g.op('Constant', value_t=torch.tensor([-1, 2])))
-    paddings = g.op(
-        'Transpose',
-        torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
-        perm_i=[1, 0])
-    paddings = g.op('Reshape', paddings,
-                    g.op('Constant', value_t=torch.tensor([-1])))
-    padding_c = g.op(
-        'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
-    return padding_c
-
-
-def constant_pad_nd(g, input, padding, value=None):
-    mode = 'constant'
-    value = sym_help._maybe_get_scalar(value)
-    value = sym_help._if_scalar_type_as(g, value, input)
-    pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
-    return g.op('Pad', input, pad, value, mode_s=mode)
-
-
-def register_extra_symbolics(opset=11):
-    register_op('topk', topk, '', opset)
-    register_op('constant_pad_nd', constant_pad_nd, '', opset)
-    register_op('masked_select', masked_select, '', opset)
-    register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
-    register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
-    register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
-    register_op('upsample_linear1d', upsample_linear1d, '', opset)
-    register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
-    register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
-    register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
diff --git a/tools/onnx_util/symbolic_helper.py b/tools/onnx_util/symbolic_helper.py
deleted file mode 100644
index 1980b57d4f..0000000000
--- a/tools/onnx_util/symbolic_helper.py
+++ /dev/null
@@ -1,300 +0,0 @@
-"""Modified from https://github.com/pytorch/pytorch."""
-from __future__ import absolute_import, division, print_function
-import warnings
-from functools import wraps
-from sys import maxsize as maxsize
-
-import torch
-import torch.onnx
-# This import monkey-patches graph manipulation methods on Graph, used for the
-# ONNX symbolics
-import torch.onnx.utils
-from torch._C import ListType
-
-# ---------------------------------------------------------------------------------
-# Helper functions
-# ---------------------------------------------------------------------------------
-
-# Save some builtins as locals, because we'll shadow them below
-_sum = sum
-
-
-def _parse_arg(value, desc):
-    if desc == 'none':
-        return value
-    if desc == 'v' or not _is_value(value):
-        return value
-    if value.node().mustBeNone():
-        return None
-    if value.node().kind() == 'onnx::Constant':
-        tval = value.node()['value']
-        if desc == 'i':
-            return int(tval)
-        elif desc == 'f':
-            return float(tval)
-        elif desc == 'b':
-            return bool(tval)
-        elif desc == 's':
-            return str(tval)
-        elif desc == 't':
-            return tval
-        elif desc == 'is':
-            return [int(v) for v in tval]
-        else:
-            raise RuntimeError(
-                "ONNX symbolic doesn't know to interpret Constant node")
-    elif value.node().kind() == 'prim::ListConstruct':
-        if desc == 'is':
-            for v in value.node().inputs():
-                if v.node().kind() != 'onnx::Constant':
-                    raise RuntimeError(
-                        "Failed to export an ONNX attribute '" +
-                        v.node().kind() +
-                        "', since it's not constant, please try to make "
-                        'things (e.g., kernel size) static if possible')
-            return [int(v.node()['value']) for v in value.node().inputs()]
-        else:
-            raise RuntimeError(
-                "ONNX symbolic doesn't know to interpret ListConstruct node")
-
-    raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
-
-
-def _maybe_get_const(value, desc):
-    if _is_value(value) and value.node().kind() == 'onnx::Constant':
-        return _parse_arg(value, desc)
-    return value
-
-
-def _maybe_get_scalar(value):
-    value_t = _maybe_get_const(value, 't')
-    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
-        return value_t
-    return value
-
-
-def _get_const(value, desc, arg_name):
-    if _is_value(value) and value.node().kind() not in ('onnx::Constant',
-                                                        'prim::Constant'):
-        raise RuntimeError('ONNX symbolic expected a constant'
-                           ' value of the {} argument, got `{}`'.format(
-                               arg_name, value))
-    return _parse_arg(value, desc)
-
-
-def _unpack_list(list_value):
-    list_node = list_value.node()
-    assert list_node.kind() == 'prim::ListConstruct'
-    return list(list_node.inputs())
-
-
-# Check if list_value is output from prim::ListConstruct
-# This is usually called before _unpack_list to ensure the list can be
-# unpacked.
-def _is_packed_list(list_value):
-    return _is_value(
-        list_value) and list_value.node().kind() == 'prim::ListConstruct'
-
-
-def parse_args(*arg_descriptors):
-
-    def decorator(fn):
-        fn._arg_descriptors = arg_descriptors
-
-        def wrapper(g, *args):
-            # some args may be optional, so the length may be smaller
-            assert len(arg_descriptors) >= len(args)
-            args = [
-                _parse_arg(arg, arg_desc)
-                for arg, arg_desc in zip(args, arg_descriptors)
-            ]
-            return fn(g, *args)
-
-        # In Python 2 functools.wraps chokes on partially applied functions, so
-        # we need this as a workaround
-        try:
-            wrapper = wraps(fn)(wrapper)
-        except Exception:
-            pass
-        return wrapper
-
-    return decorator
-
-
-def _scalar(x):
-    """Convert a scalar tensor into a Python value."""
-    assert x.numel() == 1
-    return x.item()
-
-
-def _if_scalar_type_as(g, self, tensor):
-    """Convert self into the same type of tensor, as necessary.
-
-    We only support implicit casting for scalars, so we never actually need to
-    insert an ONNX cast operator here; just fix up the scalar.
-    """
-    if isinstance(self, torch._C.Value):
-        return self
-
-    scalar_type = tensor.type().scalarType()
-    if scalar_type:
-        ty = scalar_type.lower()
-        return getattr(self, ty)()
-
-    return self
-
-
-def _is_none(x):
-    return x.node().mustBeNone()
-
-
-def _is_value(x):
-    return isinstance(x, torch._C.Value)
-
-
-def _is_tensor_list(x):
-    return x.type().isSubtypeOf(ListType.ofTensors())
-
-
-def _unimplemented(op, msg):
-    warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
-                  ' not supported')
-
-
-def _try_get_scalar_type(*args):
-    for arg in args:
-        try:
-            return arg.type().scalarType()
-        except RuntimeError:
-            pass
-    return None
-
-
-def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
-    if out is not None:
-        _unimplemented('TopK', 'Out parameter is not supported')
-    if not _is_value(k):
-        k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
-    else:
-        k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
-    return g.op(
-        'TopK',
-        input,
-        k,
-        axis_i=dim,
-        largest_i=largest,
-        sorted_i=sorted,
-        outputs=2)
-
-
-def _slice_helper(g,
-                  input,
-                  axes,
-                  starts,
-                  ends,
-                  steps=None,
-                  dynamic_slice=False):
-    # TODO(ruobing): add support for opset<10
-    from torch.onnx.symbolic_opset10 import _slice
-    return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
-
-
-def _unsqueeze_helper(g, input, dim):
-    from torch.onnx.symbolic_opset9 import unsqueeze
-    return unsqueeze(g, input, dim)
-
-
-def _interpolate_size_to_scales(g, input, output_size, dim):
-    output_size = _maybe_get_const(output_size, 'is')
-    if _is_value(output_size):
-        offset = 2
-        offsets = g.op(
-            'Constant', value_t=torch.ones(offset, dtype=torch.float32))
-        dividend = g.op(
-            'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
-        divisor = _slice_helper(
-            g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
-        divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
-        scale_dims = g.op('Div', dividend, divisor)
-        scales = g.op('Concat', offsets, scale_dims, axis_i=0)
-    else:
-        scales_constant = [
-            1. if i < 2 else float(output_size[-(dim - i)]) /
-            float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
-        ]
-        scales = g.op(
-            'Constant',
-            value_t=torch.tensor(scales_constant, dtype=torch.float32))
-    return scales
-
-
-def _interpolate_get_scales_if_available(g, scales):
-    if len(scales) == 0:
-        return None
-    available_scales = _maybe_get_const(scales[0], 'f') != -1 and not _is_none(
-        scales[0])
-
-    if not available_scales:
-        return None
-
-    scales_list = []
-    for scale in scales:
-        unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
-        # ONNX only supports float for the scales. double -> float.
-        unsqueezed_scale = g.op(
-            'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
-        scales_list.append(unsqueezed_scale)
-    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
-    scales = g.op('Concat', offsets, *scales_list, axis_i=0)
-    return scales
-
-
-def _get_interpolate_attributes(g, mode, args):
-    if mode == 'nearest':
-        align_corners = None
-        scales = args[0:]
-    else:
-        align_corners = args[0]
-        scales = args[1:]
-    scales = _interpolate_get_scales_if_available(g, scales)
-    return scales, align_corners
-
-
-def _interpolate_get_scales(g, scale_factor, dim):
-    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
-    if isinstance(scale_factor.type(), torch._C.ListType):
-        return g.op('Concat', offsets, scale_factor, axis_i=0)
-    else:
-        scale_factor = _unsqueeze_helper(g, scale_factor, 0)
-        scale_factor = g.op(
-            'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
-        scales = [scale_factor for i in range(dim - 2)]
-    scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
-    return scale_factor
-
-
-# Metaprogram symbolics for each ATen native specialized cast operator.
-# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
-# ONNX cast node with `to` attribute 'UINT8'
-#
-# TODO: remove these once we support Type's in the JIT IR and we can once again
-# use the unified toType operator
-cast_pytorch_to_onnx = {
-    'Byte': torch.onnx.TensorProtoDataType.UINT8,
-    'Char': torch.onnx.TensorProtoDataType.INT8,
-    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
-    'Float': torch.onnx.TensorProtoDataType.FLOAT,
-    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
-    'Int': torch.onnx.TensorProtoDataType.INT32,
-    'Long': torch.onnx.TensorProtoDataType.INT64,
-    'Short': torch.onnx.TensorProtoDataType.INT16,
-    'Bool': torch.onnx.TensorProtoDataType.BOOL,
-    'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
-    'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
-    'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
-}
-
-# Global set to store the list of quantized operators in the network.
-# This is currently only used in the conversion of quantized ops from PT
-# -> C2 via ONNX.
-_quantized_ops = set()
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index c97dcd93d2..dcb3badf15 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -7,8 +7,8 @@
 import torch
 import torch._C
 import torch.serialization
+from mmcv.onnx import register_extra_symbolics
 from mmcv.runner import load_checkpoint
-from onnx_util.symbolic import register_extra_symbolics
 
 from mmseg.models import build_segmentor
 
@@ -141,7 +141,7 @@ def parse_args():
         '--shape',
         type=int,
         nargs='+',
-        default=[2048, 1024],
+        default=[256, 256],
         help='input image size')
     args = parser.parse_args()
     return args

From 07e5fe5eae0f4f932a4f2824d39bcc02ea6fbf61 Mon Sep 17 00:00:00 2001
From: hanruobing
Date: Tue, 21 Jul 2020 14:21:21 +0800
Subject: [PATCH 3/5] add docstring

---
 tools/pytorch2onnx.py | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index dcb3badf15..d42f68bdcc 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -72,7 +72,21 @@ def pytorch2onnx(model,
                  opset_version=11,
                  show=False,
                  output_file='tmp.onnx',
-                 verify_onnx=False):
+                 verify=False):
+    """Export Pytorch model to ONNX model and verify that the outputs are
+    the same between Pytorch and ONNX.
+
+    Args:
+        model (nn.Module): Pytorch model we want to export.
+        input_shape (tuple): Use this input shape to construct
+            the corresponding dummy input and execute the model.
+        opset_version (int): The ONNX opset version. Default: 11.
+        show (bool): Whether to print the computation graph. Default: False.
+        output_file (str): The path where we store the output ONNX model.
+            Default: `tmp.onnx`.
+        verify (bool): Whether to compare the outputs between Pytorch and
+            ONNX. Default: False.
+    """
     model.cpu().eval()
 
     num_classes = model.decode_head.num_classes
@@ -102,7 +116,7 @@ def pytorch2onnx(model,
         print(f'Successfully exported ONNX model: {output_file}')
     model.forward = origin_forward
 
-    if verify_onnx:
+    if verify:
         # check by onnx
         import onnx
         onnx_model = onnx.load(output_file)
@@ -122,7 +136,7 @@ def pytorch2onnx(model,
         sess = rt.InferenceSession(output_file)
         onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
-        if not (pytorch_result == onnx_result).all():
+        if not np.allclose(pytorch_result, onnx_result):
             raise ValueError(
                 'The outputs are different between Pytorch and ONNX')
         print('The outputs are the same between Pytorch and ONNX')
@@ -135,8 +149,8 @@ def parse_args():
     parser.add_argument('--show', action='store_true', help='show onnx graph')
     parser.add_argument(
         '--verify', action='store_true', help='verify the onnx model')
-    parser.add_argument('--output_file', type=str, default='tmp.onnx')
-    parser.add_argument('--opset_version', type=int, default=11)
+    parser.add_argument('--output-file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset-version', type=int, default=11)
     parser.add_argument(
         '--shape',
         type=int,
@@ -174,10 +188,6 @@ def parse_args():
     if args.checkpoint:
         checkpoint = load_checkpoint(
             segmentor, args.checkpoint, map_location='cpu')
-        # old versions did not save class info in checkpoints,
-        # this workaround is for backward compatibility
-        if 'CLASSES' in checkpoint['meta']:
-            segmentor.CLASSES = checkpoint['meta']['CLASSES']
 
     # convert model to ONNX file
     pytorch2onnx(
@@ -186,4 +196,4 @@ def parse_args():
         opset_version=args.opset_version,
         show=args.show,
         output_file=args.output_file,
-        verify_onnx=args.verify)
+        verify=args.verify)

From 000d32958e2db18f934f0be06c160d949c567e1e Mon Sep 17 00:00:00 2001
From: Jiarui XU
Date: Fri, 14 Aug 2020 03:25:51 +0800
Subject: [PATCH 4/5] update docs

---
 docs/getting_started.md | 15 +++++++++++++++
 tools/pytorch2onnx.py   |  2 +-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/docs/getting_started.md b/docs/getting_started.md
index 3a9b656032..2c3d8d562a 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -342,3 +342,18 @@ python tools/publish_model.py work_dirs/pspnet/latest.pth psp_r50_hszhao_200ep.p
 ```
 
 The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pth`.
+
+### Convert to ONNX (experimental)
+
+We provide a script to convert the model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between the Pytorch and ONNX models.
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```
+
+**Note**: This tool is still experimental. Some customized operators are not supported for now.
+
+## Tutorials
+
+Currently, we provide four tutorials for users to [add new dataset](tutorials/new_dataset.md), [design data pipeline](tutorials/data_pipeline.md), [add new modules](tutorials/new_modules.md), and [use training tricks](tutorials/training_tricks.md).
+We also provide a full description of the [config system](config.md).
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index d42f68bdcc..dd0ea4f54b 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -36,7 +36,7 @@ def _convert_batchnorm(module):
     return module_output
 
 
-def _demo_mm_inputs(input_shape, num_classes):  # yapf: disable
+def _demo_mm_inputs(input_shape, num_classes):
     """Create a superset of inputs needed to run test or train batches.
 
     Args:

From b7062a55d2c14f87b4315baba520aeda21feb6ed Mon Sep 17 00:00:00 2001
From: Jiarui XU
Date: Fri, 14 Aug 2020 03:26:19 +0800
Subject: [PATCH 5/5] update docs

---
 tools/pytorch2onnx.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index dd0ea4f54b..df84eeb911 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -42,7 +42,6 @@ def _demo_mm_inputs(input_shape, num_classes):
     Args:
         input_shape (tuple):
             input batch dimensions
-
         num_classes (int):
             number of semantic classes
     """
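
A note on the `slide_inference` rewrite in PATCH 1: replacing the in-place slice assignment with `F.pad` expresses the sliding-window accumulation with plain tensor ops that the ONNX tracer can export. The snippet below is a minimal sketch of why the two formulations agree; the tensor shapes and crop coordinates are toy values chosen for illustration, not taken from any config.

```python
# Equivalence check for the slide_inference rewrite (toy shapes, assumed
# values): zero-padding a crop to full size and adding it matches the old
# in-place accumulation into a slice of `preds`.
import torch
import torch.nn.functional as F

preds = torch.zeros(1, 4, 10, 10)          # (batch, num_classes, H, W)
crop_seg_logit = torch.randn(1, 4, 6, 6)   # logits for one window
y1, x1 = 2, 3                              # top-left corner of the window
y2, x2 = y1 + 6, x1 + 6

# Old formulation: in-place accumulation into a slice.
ref = preds.clone()
ref[:, :, y1:y2, x1:x2] += crop_seg_logit

# New formulation: F.pad takes (left, right, top, bottom) for the last two
# dimensions, so the crop lands at the same location before the add.
out = preds + F.pad(crop_seg_logit,
                    (x1, preds.shape[3] - x2, y1, preds.shape[2] - y2))

assert torch.equal(ref, out)
```

The companion line `count_mat = torch.from_numpy(count_mat.detach().numpy())` works toward the same goal from the other side: it detaches the window-count tensor from the traced graph so that it is baked into the exported model as a constant, as the in-diff comment states.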
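
For readers who want to exercise the exported model outside of `tools/pytorch2onnx.py --verify`, the sketch below reproduces that check by hand with `onnxruntime`. It is a sketch under stated assumptions: a model already exported to `tmp.onnx` with the post-PATCH-2 default input shape of 1x3x256x256; both the file name and the shape are placeholders to adjust.

```python
# Hand-rolled version of the --verify path: a structural check followed by
# one forward pass through onnxruntime. Assumes tmp.onnx was exported with
# a fixed 1x3x256x256 input (the --shape default after PATCH 2).
import numpy as np
import onnx
import onnxruntime as rt

onnx_model = onnx.load('tmp.onnx')
onnx.checker.check_model(onnx_model)

# keep_initializers_as_inputs=True makes the weights appear in graph.input,
# so the real image input is the one name that is not an initializer.
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
(feed_name, ) = set(input_all) - set(input_initializer)

img = np.random.RandomState(0).rand(1, 3, 256, 256).astype(np.float32)
sess = rt.InferenceSession('tmp.onnx')
seg_pred = sess.run(None, {feed_name: img})[0]
print(seg_pred.shape)  # (1, 256, 256): per-pixel class indices from argmax
```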