[microNPU] Add unary elementwise operator infrastructure with ABS #9530

Merged (2 commits, Nov 22, 2021)
91 changes: 91 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -741,6 +741,96 @@ def __call__(self, *args, **kwargs):
pass


class UnaryElementwiseRewriter(DFPatternCallback):
"""
Convert ethosu unary elementwise composite function to
ethosu_unary_elementwise operators
"""

def __init__(self, params_class: Type, pattern: CallPattern):
super().__init__(require_type=True)
self.params_class = params_class
self.pattern = pattern

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
params.ifm.tensor = post.args[0]

if str(params.ofm.layout) != "NHWC":
Review thread on this check:

Contributor: IIRC we agreed to remove these checks, please ignore if not.

Contributor Author: All the other operators still have these checks, so maybe we can remove them in a follow-up patch for all of the operators (or, to be precise, I think we need to add some guards into the is_valid() checks to prevent attempting to offload graphs with non-NHWC layout...).

Contributor: No worries, I was referring to #9508 (comment) but didn't realize the other operators still had this check; I think a follow-up would be fine :)

raise UnsupportedLayout(str(params.ofm.layout))

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

# We don't yet support activation functions that use LUT.
lut = relay.const([], dtype="int8")

unary_input_shape = params.ifm.shape
# If the input tensor is not 4D, insert reshapes before and after the unary operator
if len(params.ifm.shape) == 4:
unary_input = params.ifm.tensor
else:
pad_size = 4 - len(unary_input_shape)
unary_input_shape = ([1] * pad_size) + unary_input_shape
unary_input = relay.op.reshape(params.ifm.tensor, newshape=unary_input_shape)

ethosu_unary_elementwise = ethosu_ops.ethosu_unary_elementwise(
ifm=unary_input,
lut=lut,
operator_type=params.operator_type,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_channels=unary_input_shape[3],
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=str(params.ifm.layout),
ofm_layout=str(params.ofm.layout),
)
if len(params.ifm.shape) == 4:
op = ethosu_unary_elementwise
else:
op = relay.op.reshape(ethosu_unary_elementwise, newshape=params.ifm.shape)
return op


class AbsRewriter(UnaryElementwiseRewriter):
def __init__(self):
super().__init__(
params_class=ethosu_patterns.AbsParams,
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AbsParams.composite_name}))(
wildcard()
),
)


@ir.transform.module_pass(opt_level=1)
class LegalizeAbs:
"""This is the pass that wraps the AbsRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(AbsRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -765,6 +855,7 @@ def transform_module(
mod = LegalizeMin()(mod)
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
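For orientation, here is a minimal hedged sketch of exercising the new rewriter on its own. It assumes a module mod that has already been partitioned for the NPU and contains an abs composite matching ethosu_patterns.AbsParams; in the normal flow LegalizeAbs runs as one step of LegalizeEthosU rather than being called directly.

import tvm
from tvm.relay.backend.contrib.ethosu import legalize

def legalize_abs_only(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
    # LegalizeAbs is a module pass, so an instance can be applied directly to a module.
    # This rewrites any matched abs composite into an ethosu_unary_elementwise call
    # with operator_type="ABS".
    return legalize.LegalizeAbs()(mod)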
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/__init__.py
@@ -21,3 +21,4 @@
from .pooling import ethosu_pooling
from .binary_elementwise import ethosu_binary_elementwise
from .identity import ethosu_identity
from .unary_elementwise import ethosu_unary_elementwise
163 changes: 163 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Relay operator for unary elementwise operations for Arm(R) Ethos(TM)-U NPU"""
from typing import Optional
import tvm
from tvm.relay.op import _make
from tvm.topi.generic import schedule_injective
from tvm.relay.op.op import OpStrategy
from tvm.relay.op import strategy as _strategy

from ..te import unary_elementwise_compute


def _extract_ethosu_unary_elementwise_params(attrs, args):
"""Get the parameters necessary to construct a ethosu_unary_elementwise compute TE
from a ethosu_unary_elementwise Relay call."""
ifm = args[0]
lut = args[1]
operator_type = attrs.operator_type
ifm_scale = attrs.ifm_scale
ifm_zero_point = attrs.ifm_zero_point
ofm_scale = attrs.ofm_scale
ofm_zero_point = attrs.ofm_zero_point
ofm_channels = attrs.ofm_channels
activation = attrs.activation
clip_min = attrs.clip_min
clip_max = attrs.clip_max
rounding_mode = attrs.rounding_mode
ifm_layout = attrs.ifm_layout
ofm_layout = attrs.ofm_layout

return (
ifm,
lut,
operator_type,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
ofm_channels,
activation,
clip_min,
clip_max,
rounding_mode,
ifm_layout,
ofm_layout,
)


@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMCompute")
def create_ethosu_unary_elementwise_compute(attrs, args, out_type):
"""Create an ethosu_unary_elementwise compute op."""
params = _extract_ethosu_unary_elementwise_params(attrs, args)
op = unary_elementwise_compute(*params)
return [op]


@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMStrategy")
def unary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
strategy = OpStrategy()
strategy.add_implementation(
create_ethosu_unary_elementwise_compute,
_strategy.wrap_topi_schedule(schedule_injective),
name="ethosu_unary_elementwise",
)
return strategy


def ethosu_unary_elementwise(
ifm: tvm.relay.Expr,
lut: tvm.relay.Expr,
operator_type: str,
ifm_scale: float,
ifm_zero_point: int,
ofm_scale: float,
ofm_zero_point: int,
ofm_channels: int,
activation: Optional[str] = "NONE",
clip_min: Optional[int] = 0,
clip_max: Optional[int] = 0,
rounding_mode: Optional[str] = "TFL",
ifm_layout: Optional[str] = "NHWC",
ofm_layout: Optional[str] = "NHWC",
) -> tvm.relay.Call:
"""This is a quantized unary elementwise operation as supported by the
NPU. It accepts either NHWC or NHCWB16 format for the input data.

Parameters
----------
ifm : tvm.relay.Expr
The Input Feature Map tensor (IFM).
lut : tvm.relay.Expr
The look-up table values to use if activation = "LUT".
operator_type: str
The type of the unary elementwise operator.
"ABS"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
The quantization zero point for the Input Feature Map tensor.
ofm_scale : float
The quantization scale for the Output Feature Map tensor.
ofm_zero_point : int
The quantization zero point for the Output Feature Map tensor.
ofm_channels : int
The number of OFM channels.
activation : str, optional
The activation function to use.
"NONE" - no activation function.
"CLIP" - clip the output between clip_min and clip_max.
"TANH" - tanh activation function.
"SIGMOID" - sigmoid activation function.
"LUT" - use a look-up table to perform the activation function.
clip_min : int, optional
The minimum clipping value if activation = "CLIP".
clip_max : int, optional
The maximum clipping value if activation = "CLIP".
rounding_mode : str, optional
The rounding mode to apply to the Output Feature Map tensor.
"TFL" - Tensorflow Lite rounding scheme.
"TRUNCATE" - Truncate towards zero.
"NATURAL" - Round to nearest value, with x.5 rounded up towards +infinity.
ifm_layout : str, optional
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".

Returns
-------
out : tvm.relay.Call
A call to the ethosu_unary_elementwise op.
"""
return _make.ethosu_unary_elementwise(
ifm,
lut,
operator_type,
ifm_scale,
ifm_zero_point,
ofm_scale,
ofm_zero_point,
ofm_channels,
activation,
clip_min,
clip_max,
rounding_mode,
ifm_layout,
ofm_layout,
)
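As a hedged usage sketch (not part of this patch), the new operator can also be constructed directly from Relay; the input shape, quantization parameters, and variable names below are illustrative assumptions.

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

# NHWC int8 input; no LUT-based activation is used, so the LUT is an empty constant.
ifm = relay.var("ifm", shape=(1, 16, 16, 8), dtype="int8")
lut = relay.const([], dtype="int8")
abs_call = ethosu_ops.ethosu_unary_elementwise(
    ifm=ifm,
    lut=lut,
    operator_type="ABS",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    ofm_channels=8,
)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(abs_call), abs_call))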
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/te/__init__.py
@@ -21,3 +21,4 @@
from .pooling import *
from .binary_elementwise import *
from .identity import *
from .unary_elementwise import *