diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 2b1ce5442..027803781 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs): else: mode = "nearest" roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32)) - const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64)) - const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64)) - const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32)) input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW}) - shape_input = ctx.make_node("Shape", [input_nchw.output[0]]) - sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]]) - size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64}) - concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0}) - resize_inputs = [ - input_nchw.output[0], - roi.output[0], - const_empty_float.output[0], - concat_shape.output[0] - ] + shape = ctx.get_shape(node.input[0]) + if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const(): + target_shape = node.inputs[1].get_tensor_value() + n, h, w, c = shape + nh, nw = target_shape + if "sizes" in node.attr: + sizes_val = np.array([1.0, 1.0, nh, nw]).astype(np.int64) + resize_params = ctx.make_const(utils.make_name("sizes"), sizes_val, raw=False) + else: # scales + scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32) + resize_params = ctx.make_const(utils.make_name("scales"), scale_val, raw=False) + resize_inputs = [ + input_nchw.output[0], + roi.output[0], + resize_params.output[0] + ] + else: + const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64)) + const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64)) + const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32)) + shape_input = ctx.make_node("Shape", [input_nchw.output[0]]) + sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]]) + size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64}) + concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0}) + resize_inputs = [ + input_nchw.output[0], + roi.output[0], + const_empty_float.output[0], + concat_shape.output[0] + ] transformation_mode = "asymmetric" nearest_mode = "floor" if "align_corners" in node.attr and node.attr["align_corners"].i: