Skip to content

Commit

Permalink
[microNPU] Add unary elementwise operator infrastructure with ABS
Browse files Browse the repository at this point in the history
* Added unary elementwise ABS legalization support and tests
* Added unary_elementwise Relay to TIR lowering and tests
* Added TIR to Vela translation and tests
* Added codegen tests

Co-authored-by: Rishabh Jain <[email protected]>
  • Loading branch information
Rishabh Jain authored and ekalda committed Nov 18, 2021
1 parent 65606c9 commit bfdf1f4
Show file tree
Hide file tree
Showing 17 changed files with 1,142 additions and 0 deletions.
91 changes: 91 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,96 @@ def __call__(self, *args, **kwargs):
pass


class UnaryElementwiseRewriter(DFPatternCallback):
"""
Convert ethosu unary elementwise composite function to
ethosu_unary_elementwise operators
"""

def __init__(self, params_class, pattern):
super().__init__(require_type=True)
self.params_class = params_class
self.pattern = pattern

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
params.ifm.tensor = post.args[0]

if str(params.ofm.layout) != "NHWC":
raise UnsupportedLayout(str(params.ofm.layout))

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

# We don't yet support activation functions that use LUT.
lut = relay.const([], dtype="int8")

unary_input_shape = params.ifm.shape
# If the input tensor is not 4D, enter reshapes before and after the unary operator
if len(params.ifm.shape) == 4:
unary_input = params.ifm.tensor
else:
while len(unary_input_shape) < 4:
unary_input_shape = [1] + unary_input_shape
unary_input = relay.op.reshape(params.ifm.tensor, newshape=unary_input_shape)

ethosu_unary_elementwise = ethosu_ops.ethosu_unary_elementwise(
ifm=unary_input,
lut=lut,
operator_type=params.operator_type,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_channels=unary_input_shape[3],
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=str(params.ifm.layout),
ofm_layout=str(params.ofm.layout),
)
if len(params.ifm.shape) == 4:
op = ethosu_unary_elementwise
else:
op = relay.op.reshape(ethosu_unary_elementwise, newshape=params.ifm.shape)
return op


class AbsRewriter(UnaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.AbsParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AbsParams.composite_name}))(
wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeAbs:
"""This is the pass that wraps the AbsRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(AbsRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand All @@ -765,6 +855,7 @@ def transform_module(
mod = LegalizeMin()(mod)
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .pooling import ethosu_pooling
from .binary_elementwise import ethosu_binary_elementwise
from .identity import ethosu_identity
from .unary_elementwise import ethosu_unary_elementwise
153 changes: 153 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Relay operator for unary elementwise operations for Arm(R) Ethos(TM)-U NPU"""
import tvm
from tvm.relay.op import _make
from tvm.topi.generic import schedule_injective
from tvm.relay.op.op import OpStrategy
from tvm.relay.op import strategy as _strategy

from ..te import unary_elementwise_compute


def _extract_ethosu_unary_elementwise_params(attrs, args):
"""Get the parameters necessary to construct a ethosu_unary_elementwise compute TE
from a ethosu_unary_elementwise Relay call."""
ifm = args[0]
lut = args[1]
operator_type = attrs.operator_type
ifm_scale = attrs.ifm_scale
ifm_zero_point = attrs.ifm_zero_point
ofm_scale = attrs.ofm_scale
ofm_zero_point = attrs.ofm_zero_point
ofm_channels = attrs.ofm_channels
activation = attrs.activation
clip_min = attrs.clip_min
clip_max = attrs.clip_max
ifm_layout = attrs.ifm_layout
ofm_layout = attrs.ofm_layout

return (
ifm,
lut,
operator_type,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
ofm_channels,
activation,
clip_min,
clip_max,
ifm_layout,
ofm_layout,
)


@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMCompute")
def create_ethosu_unary_elementwise_compute(attrs, args, out_type):
"""Create an ethosu_unary_elementwise compute op."""
params = _extract_ethosu_unary_elementwise_params(attrs, args)
op = unary_elementwise_compute(*params)
return [op]


@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMStrategy")
def unary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
strategy = OpStrategy()
strategy.add_implementation(
create_ethosu_unary_elementwise_compute,
_strategy.wrap_topi_schedule(schedule_injective),
name="ethosu_unary_elementwise",
)
return strategy


def ethosu_unary_elementwise(
ifm: tvm.relay.Expr,
lut: tvm.relay.Expr,
operator_type: str,
ifm_scale: float,
ifm_zero_point: int,
ofm_scale: float,
ofm_zero_point: int,
ofm_channels: int,
activation: str = "NONE",
clip_min: int = 0,
clip_max: int = 0,
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
"""This is a quantized unary elementwise operation as supported by the
NPU. It accepts either NHWC or NHCWB16 format for the input data.
Parameters
----------
ifm : tvm.relay.Expr
The Input Feature Map tensor (IFM).
lut : tvm.relay.Expr
The look-up table values to use if activation = "LUT".
operator_type: str
The type of the unary elementwise operator.
"ABS"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
The quantization zero point for the Input Feature Map tensor.
ofm_scale : float
The quantization scale for the Output Feature Map tensor.
ofm_zero_point : int
The quantization zero point for the Output Feature Map tensor.
ofm_channels : int
The number of OFM channels.
activation : str, optional
The activation function to use.
"NONE" - no activation function.
"CLIP" - clip the output between clip_min and clip_max.
"TANH" - tanh activation function.
"SIGMOID" - sigmoid activation function.
"LUT" - use a look-up table to perform the activation function.
clip_min : int, optional
The minimum clipping value if activation = "CLIP".
clip_max : int, optional
The maximum clipping value if activation = "CLIP".
ifm_layout : str, optional
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
Returns
-------
out : tvm.relay.Call
A call to the ethosu_binary_elementwise op.
"""
return _make.ethosu_unary_elementwise(
ifm,
lut,
operator_type,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
ofm_channels,
activation,
clip_min,
clip_max,
ifm_layout,
ofm_layout,
)
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .pooling import *
from .binary_elementwise import *
from .identity import *
from .unary_elementwise import *
119 changes: 119 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-argument
"""Tensor Expressions for unary_elementwise for the NPU"""

from tvm import te
from .dma import dma_ofm_compute, dma_ifm_compute


def unary_elementwise_compute(
ifm: te.Tensor,
lut: te.Tensor,
operator_type: str,
ifm_scale: float,
ifm_zero_point: int,
ofm_scale: float,
ofm_zero_point: int,
ofm_channels: int,
activation: str,
clip_min: int,
clip_max: int,
ifm_layout: str,
ofm_layout: str,
) -> te.Tensor:
"""A compute operator representing the capabilities of unary_elementwise for the NPU.
Parameters
----------
ifm : te.Tensor
The Input Feature Map tensor (IFM).
lut : te.Tensor
The look-up table values to use if activation = "LUT".
operator_type: str
The type of the unary elementwise operator.
"ABS"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
The quantization zero point for the Input Feature Map tensor.
ofm_scale : float
The quantization scale for the Output Feature Map tensor.
ofm_zero_point : int
The quantization zero point for the Output Feature Map tensor.
ofm_channels : int
The number of OFM channels.
activation : str
The activation function to use.
"NONE" - no activation function.
"CLIP" - clip the output between clip_min and clip_max.
"TANH" - tanh activation function.
"SIGMOID" - sigmoid activation function.
"LUT" - use a look-up table to perform the activation function.
clip_min : int
The minimum clipping value if activation = "CLIP".
clip_max : int
The maximum clipping value if activation = "CLIP".
ifm_layout : str, optional
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
Returns
-------
te.Tensor
The OFM tensor.
"""
assert ifm.shape[0] == 1
assert ifm_layout in {"NHWC", "NHCWB16"}
assert ofm_layout in {"NHWC", "NHCWB16"}

# Changing the ifm and ofm scale to conform with that expected by Vela API
ofm_scale = ifm_scale / ofm_scale
ifm_scale = 1.0

# Compute operation for the IFM DMA pipeline
dmaed_ifm = dma_ifm_compute(
ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, (0, 0, 0, 0)
)

# Unary elementwise compute operation
ofm_height = dmaed_ifm.shape[1]
ofm_width = dmaed_ifm.shape[2]

unary_elementwise_attrs = {
"op": "ethosu_unary_elementwise",
"operator_type": operator_type,
"activation": activation,
"clip_min": clip_min,
"clip_max": clip_max,
}

operators = {"ABS": te.abs}

unary_elementwise = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: operators[operator_type](
dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype)
),
name="ethosu_unary_elementwise",
attrs=unary_elementwise_attrs,
)

# Compute operation for the OFM DMA pipeline
return dma_ofm_compute(unary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels)
Loading

0 comments on commit bfdf1f4

Please sign in to comment.