Skip to content

Commit

Permalink
[microNPU] Add support for nearest neighbor and bilinear upsampling
Browse files Browse the repository at this point in the history
Adds support for 2x2 nearest neighbor and bilinear upsampling. In the
case of bilinear upsampling with align_corners set to true, the
upsampling size must be `2*input_size - 1` (as opposed to `2*input_size`).

Change-Id: I95d215eabfaac983629dcdedcda2b90efb8e0ddf
  • Loading branch information
lhutton1 committed Jan 5, 2022
1 parent 92eeef6 commit a8fecc9
Show file tree
Hide file tree
Showing 18 changed files with 767 additions and 17 deletions.
109 changes: 109 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,114 @@ def __call__(self, *args, **kwargs):
pass


class Resize2dRewriter(DFPatternCallback):
    """
    Legalize an ethos-u.resize2d composite function into NPU operators.

    The replacement emitted depends on the requested resize:
        Case 1: Requested size equals the input height/width:
            ethosu_identity without upscaling.
        Case 2: Method "nearest_neighbor":
            ethosu_identity with "NEAREST" upscaling.
        Case 3: Method "linear":
            2x2 average ethosu_pooling (stride 1) with "NEAREST" upscaling.
    For any other method the composite's input argument is returned as-is.
    """

    def __init__(self):
        super().__init__(require_type=True)
        # Match a call to any function tagged with the resize2d composite name.
        composite_fn = wildcard().has_attr(
            {"Composite": ethosu_patterns.Resize2dParams.composite_name}
        )
        self.pattern = composite_fn(wildcard())

    @staticmethod
    def _quantization_kwargs(params) -> dict:
        """Build the ifm/ofm scale and zero point keyword arguments shared by
        every operator emitted by this rewriter."""
        return {
            "ifm_scale": float(params.ifm.q_params.scale_f32),
            "ifm_zero_point": int(params.ifm.q_params.zero_point),
            "ofm_scale": float(params.ofm.q_params.scale_f32),
            "ofm_zero_point": int(params.ofm.q_params.zero_point),
        }

    def callback(
        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
    ) -> tvm.relay.Expr:
        """Replace the matched composite call with the equivalent NPU operator.

        Parameters
        ----------
        pre : tvm.relay.Expr
            The expression before rewriting (unused).
        post : tvm.relay.Expr
            The matched call; its op body is parsed for the resize parameters
            and its first argument is used as the input tensor.
        node_map : tvm.ir.container.Map
            The pattern node map (unused).

        Returns
        -------
        tvm.relay.Expr
            The replacement expression.
        """
        params = ethosu_patterns.Resize2dParams(post.op.body)
        params.ifm.tensor = post.args[0]

        empty_lut = relay.const([], "int8")
        ifm_shape = params.ifm.shape
        in_channels = ifm_shape[-1]
        target_hw = np.array([int(dim) for dim in params.size])
        ifm = params.ifm.tensor

        # Requested size already equals the input height/width: no upsampling.
        source_hw = np.array(params.ifm.shape[1:3])
        if (target_hw == source_hw).all():
            return ethosu_ops.ethosu_identity(
                ifm=ifm,
                lut=empty_lut,
                **Resize2dRewriter._quantization_kwargs(params),
            )

        if params.method == "nearest_neighbor":
            return ethosu_ops.ethosu_identity(
                ifm=ifm,
                lut=empty_lut,
                upscale="NEAREST",
                upscale_height=target_hw[0],
                upscale_width=target_hw[1],
                **Resize2dRewriter._quantization_kwargs(params),
            )

        if params.method == "linear":
            if params.coordinate_transformation_mode == "align_corners":
                # align_corners uses VALID padding, i.e. none at all.
                padding = [0, 0, 0, 0]
            else:
                # Otherwise use SAME padding, with any odd remainder placed in
                # the last two entries.
                pad_h = Resize2dRewriter.get_required_padding(ifm_shape[1])
                pad_w = Resize2dRewriter.get_required_padding(ifm_shape[2])
                padding = [pad_h // 2, pad_w // 2, (pad_h + 1) // 2, (pad_w + 1) // 2]

            return ethosu_ops.ethosu_pooling(
                ifm=ifm,
                lut=empty_lut,
                pooling_type="AVG",
                pool_shape=[2, 2],
                ofm_channels=in_channels,
                strides=[1, 1],
                padding=padding,
                upscale="NEAREST",
                upscale_height=target_hw[0],
                upscale_width=target_hw[1],
                rounding_mode="NATURAL",
                **Resize2dRewriter._quantization_kwargs(params),
            )

        # Any other method: hand back the input expression untouched.
        return ifm

    @staticmethod
    def get_required_padding(input_size: int, pool_size: int = 2) -> int:
        """Return the total padding needed along one axis so that pooling with
        `pool_size` achieves 'SAME' padding for an input of `input_size`."""
        return max(0, (input_size - 1) + pool_size - input_size)


@ir.transform.module_pass(opt_level=1)
class LegalizeResize2d:
    """Module pass that applies Resize2dRewriter to every function in a module."""

    def transform_module(
        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
    ) -> tvm.ir.IRModule:
        """Rewrite each function in `mod` with a Resize2dRewriter, store the
        result back under the same global var and return `mod`."""
        for gvar, original_func in mod.functions.items():
            mod.update_func(gvar, rewrite(Resize2dRewriter(), original_func))
        return mod

    def __call__(self, *args, **kwargs):
        # Intentionally empty stub.
        pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand Down Expand Up @@ -1271,6 +1379,7 @@ def transform_module(
mod = LegalizeMean()(mod)
mod = LegalizeConcat()(mod)
mod = LegalizeSigmoid()(mod)
mod = LegalizeResize2d()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
41 changes: 39 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/op/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,20 @@ def create_ethosu_identity_compute(attrs, args, out_type):
ofm_scale = attrs.ofm_scale
ofm_zero_point = attrs.ofm_zero_point
activation = attrs.activation
upscale = attrs.upscale
upscale_height = attrs.upscale_height
upscale_width = attrs.upscale_width
op = identity_compute(
ifm, lut, ifm_scale, ifm_zero_point, ofm_scale, ofm_zero_point, activation
ifm,
lut,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
activation,
upscale,
upscale_height,
upscale_width,
)
return [op]

Expand All @@ -61,6 +73,9 @@ def ethosu_identity(
ofm_scale: float = 1,
ofm_zero_point: int = 0,
activation: str = "NONE",
upscale: str = "NONE",
upscale_height: int = 0,
upscale_width: int = 0,
) -> tvm.relay.Call:
"""The Identity operator that runs on the NPU.
Expand All @@ -87,12 +102,34 @@ def ethosu_identity(
"TANH" - tanh activation function.
"SIGMOID" - sigmoid activation function.
"LUT" - use a look-up table to perform the activation function.
upscale: str, optional
The 2x2 upscaling mode to apply to the Input Feature Map tensor.
"NONE" - no upscaling.
"NEAREST" - upscale using nearest neighbour.
"ZEROS" - upscale using zeros.
upscale_height: int, optional
The height of the Output Feature Map after applying upscaling. A value of
0 means the height of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
upscale_width: int, optional
The width of the Output Feature Map after applying upscaling. A value of
0 means the width of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
Returns
-------
out : tvm.relay.Call
A call to the ethosu_identity op.
"""
return _make.ethosu_identity(
ifm, lut, ifm_scale, ifm_zero_point, ofm_scale, ofm_zero_point, activation
ifm,
lut,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
activation,
upscale,
upscale_height,
upscale_width,
)
16 changes: 16 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def _extract_ethosu_pooling_params(attrs, args):
clip_max = attrs.clip_max
rounding_mode = attrs.rounding_mode
upscale = attrs.upscale
upscale_height = attrs.upscale_height
upscale_width = attrs.upscale_width
ifm_layout = attrs.ifm_layout
ofm_layout = attrs.ofm_layout

Expand All @@ -66,6 +68,8 @@ def _extract_ethosu_pooling_params(attrs, args):
clip_max,
rounding_mode,
upscale,
upscale_height,
upscale_width,
ifm_layout,
ofm_layout,
)
Expand Down Expand Up @@ -107,6 +111,8 @@ def ethosu_pooling(
clip_max: int = 0,
rounding_mode: str = "TFL",
upscale: str = "NONE",
upscale_height: int = 0,
upscale_width: int = 0,
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
Expand Down Expand Up @@ -159,6 +165,14 @@ def ethosu_pooling(
"NONE" - no upscaling.
"NEAREST" - upscale using nearest neighbour.
"ZEROS" - upscale using zeros.
upscale_height: int, optional
The height of the Output Feature Map after applying upscaling. A value of
0 means the height of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
upscale_width: int, optional
The width of the Output Feature Map after applying upscaling. A value of
0 means the width of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
ifm_layout : str, optional
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
Expand Down Expand Up @@ -186,6 +200,8 @@ def ethosu_pooling(
clip_max,
rounding_mode,
upscale,
upscale_height,
upscale_width,
ifm_layout,
ofm_layout,
)
29 changes: 27 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Tensor Expression for identity"""
from tvm import te
from .dma import read_compute, write_compute
from ..util import upscale_ofm


def identity_compute(
Expand All @@ -28,6 +29,9 @@ def identity_compute(
ofm_scale: float,
ofm_zero_point: int,
activation: str,
upscale: str,
upscale_height: int,
upscale_width: int,
) -> te.Tensor:
"""A compute operator for the NPU identity operator.
Expand All @@ -51,6 +55,19 @@ def identity_compute(
"TANH" - tanh activation function.
"SIGMOID" - sigmoid activation function.
"LUT" - use a look-up table to perform the activation function.
upscale: str, optional
The 2x2 upscaling mode to apply to the Input Feature Map tensor.
"NONE" - no upscaling.
"NEAREST" - upscale using nearest neighbour.
"ZEROS" - upscale using zeros.
upscale_height: int, optional
The height of the Output Feature Map after applying upscaling. A value of
0 means the height of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
upscale_width: int, optional
The width of the Output Feature Map after applying upscaling. A value of
0 means the width of the Input Feature Map will be used. This parameter
has no effect when upscale is "NONE".
Returns
-------
Expand All @@ -59,7 +76,15 @@ def identity_compute(
"""
dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
id_attrs = {"op": "ethosu_identity", "activation": activation}
id_attrs = {
"op": "ethosu_identity",
"activation": activation,
"upscale": upscale,
}

ofm_shape = (
upscale_ofm(ifm.shape, upscale_height, upscale_width) if len(ifm.shape) == 4 else ifm.shape
)

has_lut = activation in ("TANH", "LUT", "SIGMOID")

Expand All @@ -71,7 +96,7 @@ def identity_compute(
id_attrs["lut"] = lut

identity = te.compute(
ifm.shape,
ofm_shape,
lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype),
name="ethosu_identity",
attrs=id_attrs,
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tvm import te
from .dma import dma_ofm_compute, dma_ifm_compute
from ..util import upscale_ofm


def pooling_compute(
Expand All @@ -39,6 +40,8 @@ def pooling_compute(
clip_max: int,
rounding_mode: str,
upscale: str,
upscale_height: int,
upscale_width: int,
ifm_layout: str,
ofm_layout: str,
) -> te.Tensor:
Expand Down Expand Up @@ -108,6 +111,8 @@ def pooling_compute(
# Pooling compute operation
ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1
ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1
ofm_shape = (1, ofm_height, ofm_width, ofm_channels)
ofm_shape = upscale_ofm(ofm_shape, upscale_height, upscale_width)
rh = te.reduce_axis((0, pool_shape_h), name="ry")
rw = te.reduce_axis((0, pool_shape_w), name="rx")

Expand All @@ -133,7 +138,7 @@ def pooling_compute(
pooling_attrs["lut"] = lut

pooling = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
ofm_shape,
lambda nn, hh, ww, cc: te.max(
(dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
ifm.dtype
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_identity_params(
pool_shape=SerialKernel(1, 1, 1, 1, 1, 1),
padding=SerialPadding(0, 0, 0, 0),
activation=serial_activation,
upscale="NONE",
upscale=attrs["upscale"],
rounding_mode="TFL",
),
output_pointer,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_pooling_params(
padding=serial_padding,
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
upscale="NONE",
upscale=attrs["upscale"],
),
output_pointer,
replace_pointer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _create_npu_op_conv2d(
_convert_clip_bounds(npu_conv2d_op)

npu_conv2d_op.rounding_mode = _create_npu_rounding_mode(serial_2d_convolution.rounding_mode)
npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale)
npu_conv2d_op.ifm_upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale)
accel_config = vela_api.get_accelerator_config()
weights_shape_ohwi = [
npu_conv2d_op.ofm.shape.depth,
Expand Down Expand Up @@ -505,7 +505,7 @@ def _create_npu_op_depthwise_conv2d(serial_2d_depthwise):
npu_depthwise_conv2d_op.rounding_mode = _create_npu_rounding_mode(
serial_2d_depthwise.rounding_mode
)
npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale)
npu_depthwise_conv2d_op.ifm_upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale)
target_accel_config = vela_api.get_accelerator_config()
block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_config)
npu_depthwise_conv2d_op.block_config = block_config
Expand Down Expand Up @@ -736,7 +736,7 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling):
_convert_clip_bounds(npu_pooling_op)

npu_pooling_op.rounding_mode = _create_npu_rounding_mode(serial_pooling.rounding_mode)
npu_pooling_op.upscale = _create_npu_resampling_mode(serial_pooling.upscale)
npu_pooling_op.ifm_upscale = _create_npu_resampling_mode(serial_pooling.upscale)

target_accel_config = vela_api.get_accelerator_config()
block_config = vela_api.get_optimal_block_config(npu_pooling_op, target_accel_config)
Expand Down
Loading

0 comments on commit a8fecc9

Please sign in to comment.