Skip to content

Commit

Permalink
[TFLite] Convert TFLite NCHW to NHWC (#3141)
Browse files Browse the repository at this point in the history
* Convert TFLite NCHW to NHWC

* Minor comment fix
  • Loading branch information
FrozenGene authored and srkreddy1238 committed May 22, 2019
1 parent 4b1d3d8 commit b63267b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 221 deletions.
120 changes: 15 additions & 105 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,44 +209,10 @@ def convert_reshape(self, op):
reshape_options = ReshapeOptions()
reshape_options.Init(op_options.Bytes, op_options.Pos)
target_shape = reshape_options.NewShapeAsNumpy()
input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())

in_expr = self.get_expr(input_tensor_idx)

if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
pass
elif input_shape_length == 3:
# convert N C H*W to N H*W C
in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# convert input to N H W C, then reshape to target shape,
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
msg = 'Input shape length {} for operator Reshape is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

out = _op.reshape(in_expr, newshape=tuple(target_shape))

# The rule is channel first.
# 1: N*H*W*C
# 2: N*H*W, C
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
if len(target_shape) == 1 or len(target_shape) == 2:
pass
elif len(target_shape) == 3:
out = _op.transpose(out, axes=(0, 2, 1))
elif len(target_shape) == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
raise tvm.error.OpAttributeInvalid(
'Length of target shape must be between 1 and 5 for operator Reshape.')

return out

def convert_softmax(self, op):
Expand All @@ -269,7 +235,7 @@ def convert_softmax(self, op):
return out

def convert_concatenation(self, op):
""" convert TFLite concatenation"""
"""Convert TFLite concatenation"""
try:
from tflite.Operator import Operator
from tflite.ConcatenationOptions import ConcatenationOptions
Expand All @@ -292,15 +258,6 @@ def convert_concatenation(self, op):
concatenation_options.Init(op_options.Bytes, op_options.Pos)
concatenation_axis = concatenation_options.Axis()
fused_activation_fn = concatenation_options.FusedActivationFunction()
input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())

# TFLite is N H W C, our layout is N C H W
if input_shape_length <= 4:
axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
concatenation_axis = axis_convert_map[concatenation_axis]
else:
raise NotImplementedError("Not support input shape length {} of concatenatio : "
.format(str(input_shape_length)))

# with axis in N H W C
out = _op.concatenate(in_exprs, axis=concatenation_axis)
Expand Down Expand Up @@ -336,20 +293,6 @@ def convert_add(self, op):
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str)

# In this case, we have to be careful about formatting.
input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
if input_shape_length in (1, 2):
pass
elif input_shape_length == 3:
# N H*W C to N C H*W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# N H W C to N C H W
rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
else:
msg = 'Input shape length {} for operator ADD is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

out = _op.add(lhs_expr, rhs_expr)
return out

Expand Down Expand Up @@ -440,46 +383,10 @@ def convert_squeeze(self, op):
squeeze_options = SqueezeOptions()
squeeze_options.Init(op_options.Bytes, op_options.Pos)
squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()
input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())

in_expr = self.get_expr(input_tensor_idx)

# TFLite is N H W C, our layout is N C H W
if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
pass
elif input_shape_length == 3:
# convert N C H*W to N H*W C
in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
elif input_shape_length == 4:
# convert input to N H W C, then reshape to target shape,
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
msg = 'Input shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))

# The rule is channel first.
# 1: N*H*W*C
# 2: N*H*W, C
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
if output_shape_length in (1, 2):
pass
elif output_shape_length == 3:
out = _op.transpose(out, axes=(0, 2, 1))
elif output_shape_length == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
msg = 'Output shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))

return out

def convert_fused_activation_function(self, in_expr, fused_activation_fn):
Expand Down Expand Up @@ -562,13 +469,16 @@ def convert_conv(self, op, conv_type):
params = {'kernel_size': [kernel_h, kernel_w],
'strides': [stride_h, stride_w],
'dilation': [dilation_h, dilation_w],
'padding': [0, 0]}
'padding': [0, 0],
'data_layout': 'NHWC'}

if is_depthwise_conv:
params['channels'] = int(in_channels * multiplier)
params['groups'] = int(in_channels)
params['kernel_layout'] = 'HWOI'
else:
params['channels'] = int(output_channels)
params['kernel_layout'] = 'HWIO'

# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
Expand All @@ -578,12 +488,9 @@ def convert_conv(self, op, conv_type):
in_expr = self.get_expr(input_tensor_idx)
weight_value = self.get_tensor_value(weight_tensor)

if is_depthwise_conv:
# TFLite is M KH KW IC, we require IC M KH KW
weight_value = weight_value.transpose((3, 0, 1, 2))
else:
# TFLite is OC KH KW IC, we require OC IC KH kW
weight_value = weight_value.transpose((0, 3, 1, 2))
# TFLite is OC/M KH KW IC, we require KH KW IC OC/M
# M means multiplier in depthwise convolution
weight_value = weight_value.transpose((1, 2, 3, 0))

weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

Expand All @@ -592,9 +499,10 @@ def convert_conv(self, op, conv_type):
elif padding == Padding.SAME:
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0),
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
(pad_top, pad_bottom),
(pad_left, pad_right)))
(pad_left, pad_right),
(0, 0)))
else:
raise tvm.error.OpAttributeUnimplemented(
'Padding format {} is not supported for operator Conv.'.format(padding))
Expand All @@ -610,7 +518,8 @@ def convert_conv(self, op, conv_type):
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str)
out = _op.nn.bias_add(out, bias_expr)
channel_axis = 3
out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
Expand Down Expand Up @@ -648,7 +557,8 @@ def convert_pool2d(self, op, pool_type):

params = {'pool_size': (filter_h, filter_w),
'strides': (stride_h, stride_w),
'padding': [0, 0]}
'padding': [0, 0],
'layout': 'NHWC'}

in_expr = self.get_expr(input_tensor_idx)

Expand Down
Loading

0 comments on commit b63267b

Please sign in to comment.