Skip to content

Commit

Permalink
[TARGET] Add layout_transform, clip and expand_dims in onnx converter (
Browse files Browse the repository at this point in the history
…#6366)

* Add layout_transform, clip and expand_dims in onnx converter

* remove _add_input and address comments

* address comments
  • Loading branch information
Xingyu Zhou authored Sep 5, 2020
1 parent 3451ccb commit 9aa69e2
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 26 deletions.
148 changes: 122 additions & 26 deletions python/tvm/contrib/target/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ def call_node_infer_type(node):
return types


def add_input(data, name, model_container):
def add_input(data, name, prefix, model_container):
input_name = '{}_{}'.format(prefix, name)
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[data.dtype]
tensor_value_info = onnx.helper.make_tensor_value_info(name, dtype, shape=data.shape)
tensor_value_info = onnx.helper.make_tensor_value_info(input_name, dtype, shape=data.shape)
model_container.add_inputs([tensor_value_info])
data_tensor = numpy_helper.from_array(data, name)
data_tensor = numpy_helper.from_array(data, input_name)
model_container.add_initializers([data_tensor])
return input_name


class OpConverter(object):
Expand Down Expand Up @@ -111,14 +113,16 @@ def convert(cls, node_entry, model_container, node_dict):
Relay operator accepts shape as attribute but ONNX operator
accepts it as a input.
"""

name = node_entry['name']
shape = numpy.asarray([a.value for a in node_entry['relay_node'].attrs.newshape],
dtype=numpy.int64)
input_name = 'shape{}'.format(node_entry['name'])
node = onnx.helper.make_node(cls.__name__, [node_entry['input_names'][0], input_name],

input_names = [node_entry['input_names'][0],
add_input(shape, name, 'shape', model_container)]

node = onnx.helper.make_node(cls.__name__, input_names,
node_entry['output_names'])
model_container.add_nodes([node])
add_input(shape, input_name, model_container)


class Conv(OpConverter):
Expand Down Expand Up @@ -349,13 +353,12 @@ def convert(cls, node_entry, model_container, node_dict):

name = node_entry['name']
data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64)
input_name = 'pads_{}'.format(name)
value = numpy.dtype(node_entry['types'][0].dtype).type(attrs['constant_value'])
input_value_name = 'value_{}'.format(name)
add_input(data, input_name, model_container)
add_input(value, input_value_name, model_container)

input_names = [node_entry['input_names'][0], input_name, input_value_name]
input_names = [node_entry['input_names'][0],
add_input(data, name, 'pads', model_container),
add_input(value, name, 'value', model_container)]

node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
model_container.add_nodes([node])

Expand Down Expand Up @@ -440,17 +443,16 @@ def convert(cls, node_entry, model_container, node_dict):
else:
steps += [1] * (len(shape) - len(steps))

def _add_input(val, input_name):
val_arr = numpy.asarray(val).astype(numpy.int64)
input_name = '{}_{}'.format(name, input_name)
add_input(val_arr, input_name, model_container)
return input_name
starts = numpy.asarray(starts).astype(numpy.int64)
ends = numpy.asarray(ends).astype(numpy.int64)
axes = numpy.asarray(axes).astype(numpy.int64)
steps = numpy.asarray(steps).astype(numpy.int64)

input_names = []
input_names.append(_add_input(starts, 'starts'))
input_names.append(_add_input(ends, 'ends'))
input_names.append(_add_input(axes, 'axes'))
input_names.append(_add_input(steps, 'steps'))
input_names.append(add_input(starts, name, 'starts', model_container))
input_names.append(add_input(ends, name, 'ends', model_container))
input_names.append(add_input(axes, name, 'axes', model_container))
input_names.append(add_input(steps, name, 'steps', model_container))

input_names = [node_entry['input_names'][0]] + input_names

Expand Down Expand Up @@ -511,6 +513,94 @@ def convert(cls, node_entry, model_container, node_dict):
model_container.add_nodes([slice_node])


class LayoutTransform(OpConverter):
""" Operator converter for Layouttransform
"""

@classmethod
def convert_attributes(cls, attrs):
src_layout = attrs.get_str("src_layout")
dst_layout = attrs.get_str("dst_layout")

perm = [src_layout.index(c) for c in dst_layout]
return {'perm': tuple(perm)}

@classmethod
def convert(cls, node_entry, model_container, node_dict):
attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
onnx_node = onnx.helper.make_node("Transpose",
node_entry['input_names'],
node_entry['output_names'],
**attrs)
model_container.add_nodes([onnx_node])


class Clip(OpConverter):
""" Operator converter for Clip.
"""

@classmethod
def convert_attributes(cls, attrs):
return {
'min': attrs.a_min,
'max': attrs.a_max
}

@classmethod
def convert(cls, node_entry, model_container, node_dict):
attrs = cls.convert_attributes(node_entry['relay_node'].attrs)

name = node_entry['name']

min_val = numpy.asarray(attrs['min']).astype(numpy.float32)
max_val = numpy.asarray(attrs['max']).astype(numpy.float32)

input_names = []
input_names.append(add_input(min_val, name, 'min', model_container))
input_names.append(add_input(max_val, name, 'max', model_container))

input_names = [node_entry['input_names'][0]] + input_names

node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
model_container.add_nodes([node])


class Expand(OpConverter):
""" Operator converter for Expand_dims.
"""

@classmethod
def convert_attributes(cls, attrs):
return {
'axis': attrs.axis,
'num_newaxis': attrs.num_newaxis
}

@classmethod
def convert(cls, node_entry, model_container, node_dict):
attrs = cls.convert_attributes(node_entry['relay_node'].attrs)

name = node_entry['name']

input_node = node_dict[node_entry['inputs'][0]]
assert len(input_node) == 1, "input node_entry can not be a Tuple"
input_node = input_node[0]
data_shape = input_node['types'][0].shape
new_shape = list(data_shape)

for _ in range(attrs['num_newaxis']):
new_shape.insert(attrs['axis'], 1)

new_shape = numpy.asarray(new_shape).astype(numpy.int64)
input_names = []
input_names.append(add_input(new_shape, name, 'shape', model_container))

input_names = [node_entry['input_names'][0]] + input_names

node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names'])
model_container.add_nodes([node])


class ConstantOfShapeZeros(OpConverter):
""" Operator converter for ConstantOfShape.
"""
Expand All @@ -528,17 +618,20 @@ def convert(cls, node_entry, model_container, node_dict):
assert len(input_node) == 1, "input node can not be a Tuple"
input_node = input_node[0]
dtype = input_node['types'][0].dtype
input_shape_name = 'shape_{}'.format(node_entry['name'])

name = node_entry['name']
shape = [val.value for val in input_node['types'][0].shape]
shape = numpy.asarray(shape).astype(numpy.int64)
add_input(shape, input_shape_name, model_container)

input_names = []
input_names.append(add_input(shape, name, 'shape', model_container))

dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)]
tensor_value = onnx.helper.make_tensor("value", dtype,
[1], [attrs['value']])

node = onnx.helper.make_node('ConstantOfShape',
[input_shape_name],
input_names,
node_entry['output_names'],
value=tensor_value)
model_container.add_nodes([node])
Expand Down Expand Up @@ -584,7 +677,10 @@ def convert_attributes(cls, attrs):
'ones_like': ConstantOfShapeOnes,
'subtract': rename('Sub'),
'split': Split,
'exp': rename('Exp')
'exp': rename('Exp'),
'layout_transform': LayoutTransform,
'clip': Clip,
'expand_dims': Expand
}


Expand Down Expand Up @@ -670,7 +766,7 @@ def _get_node_entry(cls, relay_node, name):
"input_names": [name], # input names in case of call nodes else self name
"output_names": [name], # output names in case of call nodes else self name
"op": None, # op name in case of call node else None
}
}

def convert_to_onnx(self, func):
""" Traverse Relay graph and generate a ONNX model"""
Expand Down
35 changes: 35 additions & 0 deletions tests/python/contrib/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,38 @@ def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype = "float32"
verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=0)
verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=1)

def test_layout_transform():
def verify_layout_transform(dshape, src_layout, dst_layout, dtype="float32"):
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
y = relay.layout_transform(x, src_layout, dst_layout)
func = relay.Function([x], y)
x_data = np.random.uniform(size=dshape).astype(dtype)
verify_results(func, [x_data], 'test_layout_transform', rtol=1e-5, atol=1e-5)

verify_layout_transform((1, 3, 8, 8), 'NCHW', 'NHWC')
verify_layout_transform((1, 8, 8, 3), 'NHWC', 'NCHW')

def test_clip():
def verify_clip(dshape, a_min, a_max, dtype="float32"):
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
y = relay.clip(x, a_min, a_max)
func = relay.Function([x], y)
x_data = np.random.uniform(size=dshape).astype(dtype)
verify_results(func, [x_data], 'test_clip', rtol=1e-5, atol=1e-5)

verify_clip((5, 5, 2, 5), 0, 0.2)
verify_clip((5, 5, 2, 5), 0.2, 0.5)

def test_expand_dims():
def verify_expand_dims(dshape, axis, num_newaxis, dtype="float32"):
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
y = relay.expand_dims(x, axis, num_newaxis)
func = relay.Function([x], y)
x_data = np.random.uniform(size=dshape).astype(dtype)
verify_results(func, [x_data], 'test_expand_dims', rtol=1e-5, atol=1e-5)

verify_expand_dims((1, 1001), 0, 2)
verify_expand_dims((1, 1, 1001), 2, 2)

if __name__ == '__main__':
test_add()
Expand All @@ -469,3 +501,6 @@ def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype = "float32"
test_cmp_type()
test_binary_op()
test_tuple_types()
test_layout_transform()
test_clip()
test_expand_dims()

0 comments on commit 9aa69e2

Please sign in to comment.