Skip to content

Commit

Permalink
Update Resize (opset 11) layer to support scales option when dims are…
Browse files Browse the repository at this point in the history
… defined (#2137)

Signed-off-by: Quentin Muller <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
  • Loading branch information
qant-um and fatcat-z authored Mar 18, 2023
1 parent ec01956 commit 5708e10
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs):
else:
mode = "nearest"
roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32))
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
shape_input = ctx.make_node("Shape", [input_nchw.output[0]])
sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]])
size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64})
concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0})
resize_inputs = [
input_nchw.output[0],
roi.output[0],
const_empty_float.output[0],
concat_shape.output[0]
]
shape = ctx.get_shape(node.input[0])
if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
target_shape = node.inputs[1].get_tensor_value()
n, h, w, c = shape
nh, nw = target_shape
if "sizes" in node.attr:
sizes_val = np.array([1.0, 1.0, nh, nw]).astype(np.int64)
resize_params = ctx.make_const(utils.make_name("sizes"), sizes_val, raw=False)
else: # scales
scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
resize_params = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
resize_inputs = [
input_nchw.output[0],
roi.output[0],
resize_params.output[0]
]
else:
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32))
shape_input = ctx.make_node("Shape", [input_nchw.output[0]])
sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]])
size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64})
concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0})
resize_inputs = [
input_nchw.output[0],
roi.output[0],
const_empty_float.output[0],
concat_shape.output[0]
]
transformation_mode = "asymmetric"
nearest_mode = "floor"
if "align_corners" in node.attr and node.attr["align_corners"].i:
Expand Down

0 comments on commit 5708e10

Please sign in to comment.