From 5708e1066a49aed98d7a3b212a051512cc129c6c Mon Sep 17 00:00:00 2001 From: Quentin <32519815+qant-um@users.noreply.github.com> Date: Sat, 18 Mar 2023 14:29:33 +0100 Subject: [PATCH] Update Resize (opset 11) layer to support scales option when dims are defined (#2137) Signed-off-by: Quentin Muller Co-authored-by: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> --- tf2onnx/onnx_opset/nn.py | 43 ++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 174780db4..0a0e78df2 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs): else: mode = "nearest" roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32)) - const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64)) - const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64)) - const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32)) input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW}) - shape_input = ctx.make_node("Shape", [input_nchw.output[0]]) - sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]]) - size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64}) - concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0}) - resize_inputs = [ - input_nchw.output[0], - roi.output[0], - const_empty_float.output[0], - concat_shape.output[0] - ] + shape = ctx.get_shape(node.input[0]) + if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const(): + target_shape = node.inputs[1].get_tensor_value() + n, h, w, c = shape + nh, nw = target_shape + if "sizes" in node.attr: + sizes_val = np.array([1.0, 1.0, nh, nw]).astype(np.int64) + resize_params = ctx.make_const(utils.make_name("sizes"), sizes_val, raw=False) + else: # scales + scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32) + resize_params = ctx.make_const(utils.make_name("scales"), scale_val, raw=False) + resize_inputs = [ + input_nchw.output[0], + roi.output[0], + resize_params.output[0] + ] + else: + const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64)) + const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64)) + const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32)) + shape_input = ctx.make_node("Shape", [input_nchw.output[0]]) + sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]]) + size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64}) + concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0}) + resize_inputs = [ + input_nchw.output[0], + roi.output[0], + const_empty_float.output[0], + concat_shape.output[0] + ] transformation_mode = "asymmetric" nearest_mode = "floor" if "align_corners" in node.attr and node.attr["align_corners"].i: