From c0a03b7848c01f0e7bb8d3f2f5348d06778536ca Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 11 Mar 2022 14:52:03 +0000 Subject: [PATCH 1/4] [microNPU] Refactor Relay to TIR hook Refactors the Relay to TIR python hook for the NPU so that optimizations can be applied across the whole module and not just functions that will be offloaded to the NPU. A pass `OutlineCompilerFunctions` is introduced to outline NPU functions, which now happens before optimization passes are run (this previously happened after the prim_func had been created). In addition, optimization passes that should only run on NPU functions are now limited to running on outlined functions for the NPU (by checking the "Compiler" attribute). To help avoid code duplication, a helpful decorator `create_npu_function_pass` has been created for python passes that should only run on NPU functions. This refactor helps move a number of passes in the microNPU codegen to use an IRModule -> IRModule philosophy. Change-Id: Icdea9ba43da0157d5ee17529d2b23b761396d112 --- .../relay/backend/contrib/ethosu/codegen.py | 127 +++-- .../relay/backend/contrib/ethosu/legalize.py | 484 ++---------------- .../tvm/relay/backend/contrib/ethosu/util.py | 39 ++ src/relay/backend/contrib/ethosu/codegen.cc | 99 ++-- .../test_ethosu/test_identity_optimizer.py | 15 +- .../test_ethosu/test_layout_optimizer.py | 13 +- .../contrib/test_ethosu/test_legalize.py | 28 +- .../contrib/test_ethosu/test_lut_optimizer.py | 10 +- .../test_outline_compiler_functions.py | 62 +++ 9 files changed, 314 insertions(+), 563 deletions(-) create mode 100644 tests/python/contrib/test_ethosu/test_outline_compiler_functions.py diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index f968d6a1f385..9ef18dd5fee5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -19,7 +19,6 @@ import tvm from tvm import relay -from tvm 
import ir from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU @@ -112,30 +111,24 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer") +@util.npu_pass(opt_level=1) class LUTsOptimizer: """Register LUTsOptimizer as a relay pass.""" - def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: - """Visit relay nodes in the given module. + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: + """Visit relay nodes in the given NPU function. Parameters ---------- func : tvm.relay.function.Function The function to apply the optimization pass for multiple LUTs to. - mod : tvm.IRModule - The module to apply the optimization pass for multiple LUTs to. Returns ------- mod : tvm.IRModule New module with optimized LUTs. """ - assert len(mod.functions.items()) == 1, "Module can only contain one function." 
- global_var, func = mod.functions.items()[0] - optimized_func = OptimizeLUTs().visit(func) - mod.update_func(global_var, optimized_func) - return mod + return OptimizeLUTs().visit(func) def __call__(self, *args, **kwargs): pass @@ -272,30 +265,27 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return super().visit_call(call) -@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer") +@util.npu_pass(opt_level=1) class LayoutOptimizer: """Register LayoutOptimizer as a Relay pass.""" - OPTIMIZE_OPS = { - "contrib.ethosu.conv2d": op.ethosu_conv2d, - "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d, - "contrib.ethosu.pooling": op.ethosu_pooling, - "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise, - "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise, - } - - def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: """A pass to optimize the layout of NPU operations. If both the producer and consumer of a tensor are NPU operators, then the layout is converted from NHWC to NHCWB16 as this is the layout NPU uses internally.""" - assert len(mod.functions.items()) == 1, "Module can only contain one function." 
- global_var, func = mod.functions.items()[0] - analyze = AnalyzeConsumers(self.OPTIMIZE_OPS) + + optimize_ops = { + "contrib.ethosu.conv2d": op.ethosu_conv2d, + "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d, + "contrib.ethosu.pooling": op.ethosu_pooling, + "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise, + "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise, + } + + analyze = AnalyzeConsumers(optimize_ops) analyze.visit(func) - optimized_func = LayoutOptimization(analyze.npu_consumers, self.OPTIMIZE_OPS).visit(func) - mod.update_func(global_var, optimized_func) - return mod + return LayoutOptimization(analyze.npu_consumers, optimize_ops).visit(func) def __call__(self, *args, **kwargs): pass @@ -312,6 +302,48 @@ def IdentityOptimizer(): # pylint: disable=invalid-name return _ffi_api.IdentityOptimizer() +def OutlineCompilerFunctions(compiler_name): # pylint: disable=invalid-name + """Pass that outlines functions given a named Compiler attribute. + + Parameters + ---------- + compiler_name + The name of the compiler to look for and outline. + + Returns + ------- + Pass + The module pass. + """ + return _ffi_api.OutlineCompilerFunctions(compiler_name) + + +@util.npu_pass(opt_level=1) +class RelayToTIR: + """Register RelayToTIR pass.""" + + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: + """Lower NPU functions to TIR.""" + # We are currently using the copy_constants scheduler. In the long run, + # this should be a single intelligent, composite scheduler + # that can perform scheduling based on user inputs such as + # scratch memory size. 
+ tir_mod, const_dict = lower_to_tir(func, copy_constants()) + + for param in const_dict.keys(): + const_dict[param] = tvm.nd.array(const_dict[param]) + + compiler_name = "ethos-u" + primfunc = tir_mod["main"] + primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) + primfunc = primfunc.with_attr("ethos-u.constants", const_dict) + primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) + return primfunc + + def __call__(self, *args, **kwargs): + pass + + @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") def constant_updater(expr, symbol): # pylint: disable=unused-argument """ @@ -322,43 +354,36 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument return dict() -@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func") -def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: +@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir") +def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: """ - This is the hook for python-based lowering of relay function - that gets offloaded to the microNPU. + This is the hook for python-based lowering of a Relay module which lowers NPU + external functions to TIR. Parameters ---------- - ext_func : relay.Function - This is the partitioned relay function + mod : tvm.ir.IRModule + This is the Relay module. Returns ------- - primfunc : tir.PrimFunc - This returns the scheduled PrimFunc + mod : tvm.ir.IRModule + The Relay module with scheduled NPU external functions. 
""" - assert len(ext_func.params) == 1 - mod = tvm.IRModule() - mod["main"] = ext_func + mod = OutlineCompilerFunctions("ethos-u")(mod) mod = LegalizeEthosU()(mod) mod = LUTsOptimizer()(mod) mod = IdentityOptimizer()(mod) mod = LayoutOptimizer()(mod) mod = relay.transform.InferType()(mod) - # We are currently using copy_constants scheduler In the long run, - # this should be a single intelligent and a composite scheduler - # that can perform scheduling based on user inputs such as - # scratch memory size. - tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants()) - - for param in const_dict.keys(): - const_dict[param] = tvm.nd.array(const_dict[param]) - - primfunc = tir_mod["main"] - primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) - primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - return primfunc + + device_contexts = { + gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), mod.functions.items()) + } + mod = mod.with_attr("device_contexts", device_contexts) + mod = RelayToTIR()(mod) + + return mod @tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact") diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 3fdcdb6c24b5..4e2737754be2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -23,7 +23,6 @@ import tvm # type: ignore from tvm import relay -from tvm import ir from tvm.relay.dataflow_pattern import DFPatternCallback # type: ignore from tvm.relay.dataflow_pattern import wildcard from tvm.relay.dataflow_pattern import is_op @@ -127,23 +126,6 @@ def callback( return relay.op.split(split_input, indices_or_sections, axis=axis).astuple() -@ir.transform.module_pass(opt_level=1) -class LegalizeSplit: - """This is the pass that wraps SplitRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> 
tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(PartitionedSplitRewriter(), func) - func = rewrite(SplitRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - def get_lut_from_func( ifm_scale: float, ifm_zp: int, @@ -244,22 +226,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeTanh: - """This is the pass that wraps TanhRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(TanhRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - def sigmoid_calc_func(x: float) -> float: """Function to calculate the values for sigmoid""" # These limits are inherited from TFLite @@ -286,22 +252,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeSigmoid: - """This is the pass that wraps SigmoidRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(SigmoidRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - def leaky_relu_calc_func(x: float, alpha: float) -> float: """Function to calculate the values for leaky relu.""" return x if x >= 0 else x * alpha @@ -322,22 +272,6 @@ def get_calc_func_params(self, expr: tvm.relay.Expr) -> Dict[str, Any]: return {"alpha": params.alpha} -@ir.transform.module_pass(opt_level=1) -class LegalizeLeakyReLU: - """This is the pass that wraps LeakyReLURewriter.""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(LeakyReLURewriter(), func) - mod.update_func(global_var, 
func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class Conv2DRewriter(DFPatternCallback): """Convert conv2d related composite functions into ethosu_conv2d operators""" @@ -405,22 +339,6 @@ def callback( return ethosu_conv2d -@ir.transform.module_pass(opt_level=1) -class LegalizeConv2D: - """This is the pass that wraps the Conv2DRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(Conv2DRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class Conv2DTransposeRewriter(DFPatternCallback): """Convert conv2d_transpose related composite functions into ethosu_conv2d_transpose operators.""" @@ -486,22 +404,6 @@ def callback( return relay.strided_slice(reduced_op, (0, 0, 0, 0), ofm_shape) -@ir.transform.module_pass(opt_level=1) -class LegalizeConv2DTranspose: - """This is the pass that wraps the Conv2DTransposeRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(Conv2DTransposeRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class DepthwiseConv2DRewriter(DFPatternCallback): """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d operators""" @@ -576,22 +478,6 @@ def callback( return ethosu_depthwise_conv2d -@ir.transform.module_pass(opt_level=1) -class LegalizeDepthwiseConv2D: - """This is the pass that wraps the DepthwiseConv2DRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(DepthwiseConv2DRewriter(), func) - mod.update_func(global_var, func) - return mod - - def 
__call__(self, *args, **kwargs): - pass - - class PoolingRewriter(DFPatternCallback): """Convert ethosu.avgpool2d and ethosu.maxpool2d composite functions to ethosu_pooling operators""" @@ -658,22 +544,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeMaxPooling: - """This is the pass that wraps the MaxPoolingRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(MaxPoolingRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class AvgPoolingRewriter(PoolingRewriter): def __init__(self): super().__init__( @@ -684,22 +554,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeAvgPooling: - """This is the pass that wraps the AvgPoolingRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(AvgPoolingRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class BinaryElementwiseRewriter(DFPatternCallback): """Convert ethosu binary elementwise composite functions to ethosu_binary_elementwise operators""" @@ -826,22 +680,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeAdd: - """This is the pass that wraps the AddRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(AddRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class SubRewriter(BinaryElementwiseRewriter): def __init__(self): super().__init__( @@ -852,22 +690,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) 
-class LegalizeSub: - """This is the pass that wraps the SubRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(SubRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class MulRewriter(BinaryElementwiseRewriter): def __init__(self): super().__init__( @@ -878,22 +700,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeMul: - """This is the pass that wraps the MulRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(MulRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class MinRewriter(BinaryElementwiseRewriter): def __init__(self): super().__init__( @@ -904,22 +710,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeMin: - """This is the pass that wraps the MinRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(MinRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class MaxRewriter(BinaryElementwiseRewriter): def __init__(self): super().__init__( @@ -930,22 +720,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeMax: - """This is the pass that wraps the MaxRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(MaxRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - 
class ShlRewriter(BinaryElementwiseRewriter): def __init__(self): super().__init__( @@ -956,22 +730,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeShl: - """This is the pass that wraps the ShlRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(ShlRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class StridedSliceRewriter(DFPatternCallback): """This pass brings the strided slice out of the partitioned function""" @@ -1005,22 +763,6 @@ def callback( return strided_slice -@ir.transform.module_pass(opt_level=1) -class LegalizeStridedSlice: - """This is the pass that wraps StridedSliceRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(StridedSliceRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class ReshapeRewriter(DFPatternCallback): """This pass brings the reshape out of the partitioned function""" @@ -1039,22 +781,6 @@ def callback( return relay.op.reshape(reshape_input, newshape=new_shape) -@ir.transform.module_pass(opt_level=1) -class LegalizeReshape: - """This is the pass that wraps ReshapeRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(ReshapeRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class NoOpRewriter(DFPatternCallback): """This pass adds an idenity operator to reshape and strided slice to avoid a no op without a consumer""" @@ -1073,22 +799,6 @@ def callback( return 
ethosu_ops.ethosu_identity(ifm=post, lut=relay.const([], dtype="int8")) -@ir.transform.module_pass(opt_level=1) -class LegalizeNoOps: - """This is the pass that wraps RewriteNoOps""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(NoOpRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class UnaryElementwiseRewriter(DFPatternCallback): """ Convert ethosu unary elementwise composite function to @@ -1160,22 +870,6 @@ def __init__(self): ) -@ir.transform.module_pass(opt_level=1) -class LegalizeAbs: - """This is the pass that wraps the AbsRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(AbsRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class MeanRewriter(DFPatternCallback): """Convert ethosu.mean composite functions to to an equivalent legalization: - Case 1 (axis == [1, 2] and keepsdims == True): @@ -1324,22 +1018,6 @@ def callback( return reduced_op -@ir.transform.module_pass(opt_level=1) -class LegalizeMean: - """This is the pass that wraps the MeanRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(MeanRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class ConcatRewriter(DFPatternCallback): """The newer versions of TFLite converters return a concatenate operator that concatenates tensors with same QNN params (if the QNN params of tensors were initially different, @@ -1366,22 +1044,6 @@ def callback( return concat -@ir.transform.module_pass(opt_level=1) 
-class LegalizeConcat: - """This is the pass that wraps ConcatRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(ConcatRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class RequantizeRewriter(DFPatternCallback): """Convert ethos-u.requantize composite function to an identity operation.""" @@ -1409,22 +1071,6 @@ def callback( ) -@ir.transform.module_pass(opt_level=1) -class LegalizeRequantize: - """This is the pass that wraps RequantizeRewriter.""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(RequantizeRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class Resize2dRewriter(DFPatternCallback): """ Convert ethos-u.resize2d composite function to an equivalent operation that @@ -1504,22 +1150,6 @@ def get_required_padding(input_size: int, pool_size: int = 2) -> int: return total_padding -@ir.transform.module_pass(opt_level=1) -class LegalizeResize2d: - """This is the pass that wraps Resize2dRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(Resize2dRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class ExpandDimsRewriter(DFPatternCallback): """Legalize expand dims to a reshape operator.""" @@ -1536,22 +1166,6 @@ def callback( return relay.op.reshape(post.args[0], newshape=params.output.shape) -@ir.transform.module_pass(opt_level=1) -class LegalizeExpandDims: - """This is the pass that wraps ExpandDimsRewriter.""" - - def transform_module( - self, mod: 
tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(ExpandDimsRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class SqueezeRewriter(DFPatternCallback): """Legalize squeeze to a reshape operator.""" @@ -1568,22 +1182,6 @@ def callback( return relay.op.reshape(post.args[0], newshape=params.output.shape) -@ir.transform.module_pass(opt_level=1) -class LegalizeSqueeze: - """This is the pass that wraps SqueezeRewriter.""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(SqueezeRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - class FullyConnectedRewriter(DFPatternCallback): """Legalize Fully Connected (with bias and clip) to an NPU operator""" @@ -1654,62 +1252,50 @@ def callback(self, pre, post, node_map): return ethosu_fc -@ir.transform.module_pass(opt_level=1) -class LegalizeFullyConnected: - """This is the pass that wraps the FullyConnectedRewriter""" - - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: - for global_var, func in mod.functions.items(): - func = rewrite(FullyConnectedRewriter(), func) - mod.update_func(global_var, func) - return mod - - def __call__(self, *args, **kwargs): - pass - - -@ir.transform.module_pass(opt_level=1) +@util.npu_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation in a way such that the operations are replaced with hardware/codegen supported operations. 
""" - def transform_module( - self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext - ) -> tvm.ir.IRModule: + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: """This is the method that replaces the operations with hardware/codegen supported operations. """ - mod = LegalizeSplit()(mod) - mod = LegalizeConv2D()(mod) - mod = LegalizeConv2DTranspose()(mod) - mod = LegalizeDepthwiseConv2D()(mod) - mod = LegalizeMaxPooling()(mod) - mod = LegalizeAvgPooling()(mod) - mod = LegalizeAdd()(mod) - mod = LegalizeSub()(mod) - mod = LegalizeMul()(mod) - mod = LegalizeMin()(mod) - mod = LegalizeMax()(mod) - mod = LegalizeShl()(mod) - mod = LegalizeAbs()(mod) - mod = LegalizeTanh()(mod) - mod = LegalizeLeakyReLU()(mod) - mod = LegalizeMean()(mod) - mod = LegalizeConcat()(mod) - mod = LegalizeSigmoid()(mod) - mod = LegalizeRequantize()(mod) - mod = LegalizeResize2d()(mod) - mod = LegalizeExpandDims()(mod) - mod = LegalizeSqueeze()(mod) - mod = LegalizeReshape()(mod) - mod = LegalizeStridedSlice()(mod) - mod = LegalizeFullyConnected()(mod) - mod = LegalizeNoOps()(mod) - return mod + rewriters = [ + PartitionedSplitRewriter(), + SplitRewriter(), + Conv2DRewriter(), + Conv2DTransposeRewriter(), + DepthwiseConv2DRewriter(), + FullyConnectedRewriter(), + MaxPoolingRewriter(), + AvgPoolingRewriter(), + AddRewriter(), + SubRewriter(), + MulRewriter(), + MinRewriter(), + MaxRewriter(), + ShlRewriter(), + AbsRewriter(), + TanhRewriter(), + LeakyReLURewriter(), + MeanRewriter(), + ConcatRewriter(), + SigmoidRewriter(), + RequantizeRewriter(), + Resize2dRewriter(), + ExpandDimsRewriter(), + SqueezeRewriter(), + ReshapeRewriter(), + StridedSliceRewriter(), + NoOpRewriter(), + ] + for rewriter in rewriters: + func = rewrite(rewriter, func) + + return func def __call__(self, *args, **kwargs): # pylint is unable figure out the decorated diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 
dffc237e791c..5b682eb14230 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -143,6 +143,11 @@ class QDenseArgs(Enum): WEIGHTS_SCALE = 5 +def is_npu_func(func: relay.Function) -> bool: + """Check if the given function is an NPU function.""" + return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u" + + def is_composite_func(func: relay.Function, name: str) -> bool: """ This method checks whether the call is to @@ -313,3 +318,37 @@ def __init__( encoded_constants, base_addresses, ) + + +def npu_pass(opt_level: int, name: str = ""): + """ + A utility decorator that wraps a given class as an NPU module pass. That is, + a pass that behaves like a module pass but only traverses NPU external + functions. How NPU functions are mutated is defined by `transform_npu_function`. + + Parameters + ---------- + opt_level: int + Optimization level for the module pass. + name: str, optional + Name for the module pass. + + Returns + ------- + decorator + The npu_pass decorator. + """ + + def decorator(npu_pass_class): + @tvm.ir.transform.module_pass(name=name, opt_level=opt_level) + class ModulePassWrapper: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.ir.IRModule: + npu_functions = filter(lambda x: is_npu_func(x[1]), mod.functions.items()) + for global_var, func in npu_functions: + func = npu_pass_class().transform_npu_function(global_var, func) + mod.update_func(global_var, func) + return mod + + return ModulePassWrapper + + return decorator diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index ca41ccd14257..7044669d23b5 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -48,59 +48,63 @@ namespace contrib { namespace ethosu { /*! 
- * \brief This mutator lowers each external - * relay function to a TIR PrimFunc + * \brief This mutator outlines functions that are marked with a named + * "Compiler" attribute. Functions that do not match this condition remain + * unaltered. */ -class RelayToTIRMutator : public MixedModeMutator { +class OutlineCompilerFunctionsMutator : public MixedModeMutator { public: - explicit RelayToTIRMutator(IRModule ir_module) : ir_module_(ir_module) {} - - IRModule operator()() { - GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - Function main = Downcast(ir_module_->Lookup(main_global_var)); - Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); - - ir_module_->Update(main_global_var, mutated_main); - ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); - return ir_module_; - } + explicit OutlineCompilerFunctionsMutator(const IRModule& mod, const std::string& compiler_name) + : mod_(mod), compiler_name_(compiler_name) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { Call call = Downcast(post); if (call->op->IsInstance()) { Function func = Downcast(call->op); - auto codegen_name = func->GetAttr(attr::kCompiler); - if (codegen_name.defined() && codegen_name == "ethos-u") { - auto relay_to_tir_func_pf = - tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir_func"); - ICHECK(relay_to_tir_func_pf); - tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func); - prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u")); - String symbol_name = prim_func->GetAttr(tvm::attr::kGlobalSymbol).value(); - GlobalVar gv(symbol_name); - Array args = call->args; - gv->checked_type_ = func->checked_type(); - ir_module_->Update(gv, prim_func); - device_contexts_.Set(gv, codegen_name.value()); - return Call(gv, args, call->attrs, call->type_args); + auto compiler = func->GetAttr(attr::kCompiler); + if (compiler.defined() && compiler == compiler_name_) { + auto gv_name = 
func->GetAttr("global_symbol").value_or(""); + ICHECK_NE(gv_name, "") + << "Function to be outlined must have global_symbol attribute, but didn't."; + GlobalVar gv(gv_name); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type(); + } + mod_->Update(gv, func); + return Call(gv, call->args, call->attrs, call->type_args); } } return post; } private: - IRModule ir_module_; - Map device_contexts_; + IRModule mod_; + std::string compiler_name_; }; -tvm::transform::Pass RelayToTIR() { +/*! + * \brief A pass to outline compiler specific functions. + */ +tvm::transform::Pass OutlineCompilerFunctions(const std::string& compiler_name) { runtime::TypedPackedFunc pass_func = - [=](IRModule ir_module, transform::PassContext pass_context) { - return RelayToTIRMutator(ir_module)(); + [=](IRModule mod, transform::PassContext ctx) { + GlobalVar gv = mod->GetGlobalVar("main"); + Function main_func = Downcast(mod->Lookup("main")); + auto new_main_body = + OutlineCompilerFunctionsMutator(mod, compiler_name).VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + Function new_main_func = WithFields(main_func, main_func->params, new_main_body); + mod->Update(gv, new_main_func); + } + return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethos-u.RelayToTIR", {}); + return tvm::transform::CreateModulePass( + pass_func, 0, "relay.backend.contrib.ethos-u.OutlineCompilerFunctions", {}); } +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.OutlineCompilerFunctions") + .set_body_typed(OutlineCompilerFunctions); + /*! * \brief This mutator removes identity operations that are not necessary. 
Specifically, an * identity operation can be removed when it is immediately followed by an NPU compute @@ -161,11 +165,14 @@ tvm::transform::Pass IdentityOptimizer() { runtime::TypedPackedFunc pass_func = [=](IRModule mod, transform::PassContext ctx) { for (auto gv : mod->GetGlobalVars()) { - Function main_func = Downcast(mod->Lookup(gv)); - auto new_main_body = RemoveRedundantIdentities().VisitExpr(main_func->body); - if (!new_main_body.same_as(main_func->body)) { - Function new_main_func = WithFields(main_func, main_func->params, new_main_body); - mod->Update(gv, new_main_func); + Function func = Downcast(mod->Lookup(gv)); + auto compiler_name = func->GetAttr(attr::kCompiler); + if (compiler_name.defined() && compiler_name == "ethos-u") { + auto new_body = RemoveRedundantIdentities().VisitExpr(func->body); + if (!new_body.same_as(func->body)) { + Function new_func = WithFields(func, func->params, new_body); + mod->Update(gv, new_func); + } } } return mod; @@ -176,6 +183,20 @@ tvm::transform::Pass IdentityOptimizer() { TVM_REGISTER_GLOBAL("relay.ext.ethos-u.IdentityOptimizer").set_body_typed(IdentityOptimizer); +/*! + * \brief This pass will lower NPU functions in a Relay module to scheduled TIR prim functions. + */ +tvm::transform::Pass RelayToTIR() { + runtime::TypedPackedFunc pass_func = + [=](IRModule ir_module, transform::PassContext pass_context) { + auto relay_to_tir_pf = tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir"); + ICHECK(relay_to_tir_pf); + ir_module = (*relay_to_tir_pf)(ir_module); + return ir_module; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethos-u.RelayToTIR", {}); +} + /*! 
* \brief This function lowers the IRModule with PrimFunc * with the target of the microNPU to a C-source runtime module diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py b/tests/python/contrib/test_ethosu/test_identity_optimizer.py index 833b8d089dc8..a2bb4f465a8a 100644 --- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py @@ -28,21 +28,22 @@ import tvm from tvm import relay from tvm.relay.op.contrib.ethosu import partition_for_ethosu -from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir from tvm.relay.backend.contrib.ethosu.codegen import IdentityOptimizer from . import infra from .test_codegen import _compare_tvm_with_tflite -def _optimize(expr, optimize=True): +def _optimize(func, optimize=True): """Create IRModule and run identity optimizer pass.""" - mod = tvm.IRModule.from_expr(expr) + func = func.with_attr("Compiler", "ethos-u") + mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) if optimize: mod = IdentityOptimizer()(mod) entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body + return entry if isinstance(func, relay.Function) else entry.body def _assert_structural_equal(a, b): @@ -266,7 +267,7 @@ def get_graph(get_expected=False): _assert_structural_equal(actual, expected) -def test_layout_optimizer_runs_in_compilation_pipeline(): +def test_identity_optimizer_runs_in_compilation_pipeline(): """Checks that the identity optimization pass is run as part of the NPU compilation pipeline.""" def get_graph(): @@ -278,10 +279,10 @@ def get_graph(): mod = get_graph() mod = partition_for_ethosu(mod) + mod = relay_to_tir(mod) external_gv_name = mod["main"].body.op.name_hint - external_func = mod[external_gv_name] - prim_func = relay_to_tir_func(external_func) + prim_func = mod[external_gv_name] # Check for hints in the TIR prim 
func that the identity optimization pass # has ran. There should not be an identity in the prim func. diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index 9199cdd7f014..a2161c775b06 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -33,19 +33,20 @@ from tvm import relay from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer -from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir from . import infra -def _optimize(expr, optimize=True): +def _optimize(func, optimize=True): """Create IRModule and run layout optimizer pass.""" - mod = tvm.IRModule.from_expr(expr) + func = func.with_attr("Compiler", "ethos-u") + mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) if optimize: mod = LayoutOptimizer()(mod) entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body + return entry if isinstance(func, relay.Function) else entry.body def _assert_structural_equal(a, b): @@ -721,10 +722,10 @@ def get_graph(): mod = get_graph() mod = partition_for_ethosu(mod) + mod = relay_to_tir(mod) external_gv_name = mod["main"].body.op.name_hint - external_func = mod[external_gv_name] - prim_func = relay_to_tir_func(external_func) + prim_func = mod[external_gv_name] # Check for hints in the TIR prim func that the layout optimization pass has ran ops = prim_func.body.body.seq diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 710c3e8c8812..32cf2c1e9255 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -102,15 +102,21 @@ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3), 
float32]) -> (Tens """ return tvm.parser.fromtext(expected_ir_string) + rewrite_split = [legalize.PartitionedSplitRewriter(), legalize.SplitRewriter()] + mod_axis1 = tvm.IRModule() - mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1) - mod_axis1 = legalize.LegalizeSplit()(mod_axis1) + func = create_graph(1) + for r in rewrite_split: + func = dataflow_pattern.rewrite(r, func) + mod_axis1["tvmgen_default_ethos_u_main_0"] = func expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() - mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2) - mod_axis2 = legalize.LegalizeSplit()(mod_axis2) + func = create_graph(2) + for r in rewrite_split: + func = dataflow_pattern.rewrite(r, func) + mod_axis2["tvmgen_default_ethos_u_main_0"] = func expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) @@ -198,15 +204,21 @@ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tens """ return tvm.parser.fromtext(expected_ir_string) + rewrite_split = [legalize.PartitionedSplitRewriter(), legalize.SplitRewriter()] + mod_axis1 = tvm.IRModule() - mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1, 5) - mod_axis1 = legalize.LegalizeSplit()(mod_axis1) + func = create_graph(1, 5) + for r in rewrite_split: + func = dataflow_pattern.rewrite(r, func) + mod_axis1["tvmgen_default_ethos_u_main_0"] = func expected_axis1 = expected_mod_axis1() tvm.ir.assert_structural_equal(mod_axis1, expected_axis1) mod_axis2 = tvm.IRModule() - mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2, 5) - mod_axis2 = legalize.LegalizeSplit()(mod_axis2) + func = create_graph(2, 5) + for r in rewrite_split: + func = dataflow_pattern.rewrite(r, func) + mod_axis2["tvmgen_default_ethos_u_main_0"] = func expected_axis2 = expected_mod_axis2() tvm.ir.assert_structural_equal(mod_axis2, expected_axis2) diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py 
b/tests/python/contrib/test_ethosu/test_lut_optimizer.py index d9a543c1a771..db2a1d5a88a9 100644 --- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -27,7 +27,7 @@ import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer -from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir from tvm.relay.op.contrib.ethosu import partition_for_ethosu from .test_codegen import _get_tflite_graph @@ -49,6 +49,7 @@ def before(): id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="SIGMOID") func = relay.Function(relay.analysis.free_vars(id2), id2) + func = func.with_attr("Compiler", "ethos-u") mod = tvm.IRModule.from_expr(func) return mod @@ -61,6 +62,7 @@ def after(): ) func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = func.with_attr("Compiler", "ethos-u") mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) return mod @@ -84,6 +86,7 @@ def before(): id2 = infra.make_ethosu_identity(id1, lut=lut2, activation="TANH") func = relay.Function(relay.analysis.free_vars(id2), id2) + func = func.with_attr("Compiler", "ethos-u") mod = tvm.IRModule.from_expr(func) return mod @@ -94,6 +97,7 @@ def after(): id2 = infra.make_ethosu_identity(conv1, lut=lut2, activation="TANH") func = relay.Function(relay.analysis.free_vars(id2), id2) + func = func.with_attr("Compiler", "ethos-u") mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) return mod @@ -119,10 +123,10 @@ def get_graph(x): mod, _ = _get_tflite_graph(get_graph, [ifm_shape]) mod = partition_for_ethosu(mod) + mod = relay_to_tir(mod) external_gv_name = mod["main"].body.op.name_hint - external_func = mod[external_gv_name] - prim_func = relay_to_tir_func(external_func) + prim_func = mod[external_gv_name] # Check for hints in the TIR prim func that the LUT optimization pass has 
ran. # If the module was optimized, there should be no identity operations. diff --git a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py new file mode 100644 index 000000000000..29c40cbd2891 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Test the outline compiler functions pass. 
+""" + +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.codegen import OutlineCompilerFunctions + + +def test_outline_compiler_functions(): + compiler_name = "my-compiler" + + def before(): + inp = relay.var("input") + + x = relay.var("x", shape=(1, 2, 2, 4)) + x = relay.reshape(x, newshape=(1, 4, 4)) + x = relay.Function(relay.analysis.free_vars(x), x) + x = x.with_attr("Compiler", compiler_name) + x = x.with_attr("global_symbol", "ext_func") + + out = relay.Call(x, [inp]) + out = relay.Function([inp], out) + return tvm.ir.IRModule.from_expr(out) + + def expected(): + mod = tvm.ir.IRModule() + + inp = relay.var("input") + + x = relay.var("x", shape=(1, 2, 2, 4)) + x = relay.reshape(x, newshape=(1, 4, 4)) + x = relay.Function(relay.analysis.free_vars(x), x) + x = x.with_attr("Compiler", compiler_name) + x = x.with_attr("global_symbol", "ext_func") + mod["ext_func"] = x + + out = relay.Call(mod.get_global_var("ext_func"), [inp]) + mod["main"] = relay.Function([inp], out) + return mod + + after = OutlineCompilerFunctions(compiler_name)(before()) + exp = expected() + assert after["ext_func"] + assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"]) From 0b7977148135352a072257dbb4efb70431857cca Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 14 Mar 2022 15:02:07 +0000 Subject: [PATCH 2/4] add mixed compilers to test Change-Id: I3ca48738e096bb0f4dc362f0e9550317fc0d5afd --- .../test_outline_compiler_functions.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py index 29c40cbd2891..91458f60e172 100644 --- a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py +++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py @@ -19,6 +19,10 @@ Test the outline compiler functions pass. 
""" +import pytest + +pytest.importorskip("ethosu.vela") + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.codegen import OutlineCompilerFunctions @@ -26,17 +30,27 @@ def test_outline_compiler_functions(): compiler_name = "my-compiler" + wrong_compiler_name = "wrong-compiler" def before(): inp = relay.var("input") + # Inlined functions for "my-compiler" x = relay.var("x", shape=(1, 2, 2, 4)) x = relay.reshape(x, newshape=(1, 4, 4)) x = relay.Function(relay.analysis.free_vars(x), x) x = x.with_attr("Compiler", compiler_name) x = x.with_attr("global_symbol", "ext_func") + # Inlined function for "wrong-compiler" + y = relay.var("y", shape=(1, 4, 4)) + y = relay.reshape(y, newshape=(1, 16)) + y = relay.Function(relay.analysis.free_vars(y), y) + y = y.with_attr("Compiler", wrong_compiler_name) + y = y.with_attr("global_symbol", "ext_func_2") + out = relay.Call(x, [inp]) + out = relay.Call(y, [out]) out = relay.Function([inp], out) return tvm.ir.IRModule.from_expr(out) @@ -52,11 +66,21 @@ def expected(): x = x.with_attr("global_symbol", "ext_func") mod["ext_func"] = x + y = relay.var("y", shape=(1, 4, 4)) + y = relay.reshape(y, newshape=(1, 16)) + y = relay.Function(relay.analysis.free_vars(y), y) + y = y.with_attr("Compiler", wrong_compiler_name) + y = y.with_attr("global_symbol", "ext_func_2") + out = relay.Call(mod.get_global_var("ext_func"), [inp]) + out = relay.Call(y, [out]) mod["main"] = relay.Function([inp], out) return mod after = OutlineCompilerFunctions(compiler_name)(before()) exp = expected() - assert after["ext_func"] + + global_vars = [str(gv) for gv in after.get_global_vars()] + assert "@ext_func" in global_vars + assert "@ext_func_2" not in global_vars assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"]) From 92dee34817d9a1c87c8458ba7df1db133529c7da Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 17 Mar 2022 11:42:49 +0000 Subject: [PATCH 3/4] Address comments including renaming both npu_pass and RelayToTIR This 
commit renames `npu_pass` -> `create_npu_function_pass`. It also renames the `RelayToTIR` pass created in Python to `LowerToTIR`, along with moving it to compiler.py to make it clear that this pass is a wrapper around the `_lower_to_tir` function. In addition, to make it explicit that the `lower_to_tir` func->func pass should not be used directly it has been renamed to `_lower_to_tir` - it is being maintained since it is used in many tests. Change-Id: I3a0a06801f029aeaa4a51c2d86d8703bb0d7afbb --- .../relay/backend/contrib/ethosu/codegen.py | 39 +++++-------------- .../relay/backend/contrib/ethosu/legalize.py | 2 +- .../backend/contrib/ethosu/tir/compiler.py | 38 +++++++++++++++++- .../tvm/relay/backend/contrib/ethosu/util.py | 32 ++++++++++++--- .../contrib/test_ethosu/test_compiler.py | 4 +- .../test_ethosu/test_encode_constants.py | 12 +++--- .../test_ethosu/test_remove_concatenates.py | 4 +- .../test_replace_binary_elementwise.py | 6 +-- .../test_ethosu/test_replace_conv2d.py | 12 +++--- .../contrib/test_ethosu/test_replace_copy.py | 6 +-- .../test_replace_depthwise_conv2d.py | 4 +- .../test_ethosu/test_replace_identity.py | 4 +- .../test_ethosu/test_replace_pooling.py | 6 +-- .../test_replace_unary_elementwise.py | 4 +- .../contrib/test_ethosu/test_scheduler.py | 4 +- 15 files changed, 107 insertions(+), 70 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 9ef18dd5fee5..123a92d96f56 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -19,7 +19,7 @@ import tvm from tvm import relay -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import 
tir_to_cs_translator @@ -111,7 +111,7 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@util.npu_pass(opt_level=1) +@util.create_npu_function_pass(opt_level=1) class LUTsOptimizer: """Register LUTsOptimizer as a relay pass.""" @@ -265,7 +265,7 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return super().visit_call(call) -@util.npu_pass(opt_level=1) +@util.create_npu_function_pass(opt_level=1) class LayoutOptimizer: """Register LayoutOptimizer as a Relay pass.""" @@ -318,32 +318,6 @@ def OutlineCompilerFunctions(compiler_name): # pylint: disable=invalid-name return _ffi_api.OutlineCompilerFunctions(compiler_name) -@util.npu_pass(opt_level=1) -class RelayToTIR: - """Register RelayToTIR pass.""" - - def transform_npu_function(self, _, func: relay.Function) -> relay.Function: - """Lower NPU functions to TIR.""" - # We are currently using copy_constants scheduler In the long run, - # this should be a single intelligent and a composite scheduler - # that can perform scheduling based on user inputs such as - # scratch memory size. 
- tir_mod, const_dict = lower_to_tir(func, copy_constants()) - - for param in const_dict.keys(): - const_dict[param] = tvm.nd.array(const_dict[param]) - - compiler_name = "ethos-u" - primfunc = tir_mod["main"] - primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) - primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) - return primfunc - - def __call__(self, *args, **kwargs): - pass - - @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") def constant_updater(expr, symbol): # pylint: disable=unused-argument """ @@ -381,7 +355,12 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), mod.functions.items()) } mod = mod.with_attr("device_contexts", device_contexts) - mod = RelayToTIR()(mod) + + # We are currently using the copy_constants scheduler. In the long run, + # this should be a single intelligent, composite scheduler + # that can perform scheduling based on user inputs such as + # scratch memory size. 
+ mod = LowerToTIR(copy_constants)(mod) return mod diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 4e2737754be2..6f37b90f0f97 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1252,7 +1252,7 @@ def callback(self, pre, post, node_map): return ethosu_fc -@util.npu_pass(opt_level=1) +@util.create_npu_function_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation in a way such that the operations are replaced with hardware/codegen supported diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index bdc3b3186202..aa15d916ee98 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -23,6 +23,7 @@ from . import passes as ethosu_passes from .scheduler import schedule +from .. import util def lower_ethosu(sch, args, const_dict, name="main"): @@ -172,7 +173,42 @@ def extract_constants(func): return new_func, const_dict -def lower_to_tir(func, cascader=None): +@util.create_npu_function_pass(opt_level=1) +class LowerToTIR: + """A pass that lowers NPU Relay functions to TIR. This pass wraps + the _lower_to_tir pass that operates function->function, while this + is IRModule->IRModule. + + Attributes + ---------- + scheduler : callable + A function to schedule NPU operations. For example, + scheduler.py/copy_constants. 
+ """ + + def __init__(self, scheduler): + self.scheduler = scheduler + + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: + """Lower NPU functions to TIR.""" + + tir_mod, const_dict = _lower_to_tir(func, self.scheduler()) + + for param in const_dict.keys(): + const_dict[param] = tvm.nd.array(const_dict[param]) + + compiler_name = "ethos-u" + primfunc = tir_mod["main"] + primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) + primfunc = primfunc.with_attr("ethos-u.constants", const_dict) + primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) + return primfunc + + def __call__(self, *args, **kwargs): + pass + + +def _lower_to_tir(func, cascader=None): """Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target. The Relay function should only contain operations supported diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 5b682eb14230..16b215143d9e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -320,11 +320,26 @@ def __init__( ) -def npu_pass(opt_level: int, name: str = ""): +def create_npu_function_pass(opt_level: int, name: str = ""): """ - A utility decorator that wraps a given class as an NPU module pass. That is, - a pass that behaves like a module pass but only traverses NPU external - functions. How NPU functions are mutated is defined by `transform_npu_function`. + A utility decorator that wraps a given class as an NPU function pass. That is, + a pass that behaves like a function pass and only traverses NPU external + functions. How each NPU function is mutated is defined by the + `transform_npu_function(global_variable, relay_function)` function which should + be created in the class that is to be decorated. See the example below. + + Example + ------- + This small example demonstrates a pass over NPU functions that performs no + mutation. 
+ + @create_npu_function_pass(opt_level=1) + class MyPass: + def transform_npu_function(self, global_var, func): + return func + + mod = tvm.IRModule() + mod = MyPass()(mod) Parameters ---------- @@ -342,10 +357,17 @@ def npu_pass(opt_level: int, name: str = ""): def decorator(npu_pass_class): @tvm.ir.transform.module_pass(name=name, opt_level=opt_level) class ModulePassWrapper: + """The wrapper for the NPU pass.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.ir.IRModule: npu_functions = filter(lambda x: is_npu_func(x[1]), mod.functions.items()) for global_var, func in npu_functions: - func = npu_pass_class().transform_npu_function(global_var, func) + npu_pass = npu_pass_class(*self.args, **self.kwargs) + func = npu_pass.transform_npu_function(global_var, func) mod.update_func(global_var, func) return mod diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py index 0e31be86becb..5da91632bd86 100644 --- a/tests/python/contrib/test_ethosu/test_compiler.py +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -19,7 +19,7 @@ pytest.importorskip("ethosu.vela") import tvm from tvm import relay -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from . 
import infra @@ -57,7 +57,7 @@ def test_lower_to_tir_arg_count(relay_function, arg_count): mod = tvm.IRModule() mod["main"] = relay_function() mod = relay.transform.InferType()(mod) - tir_mod = lower_to_tir(mod["main"])[0] + tir_mod = _lower_to_tir(mod["main"])[0] primfunc = tir_mod["main"] assert len(primfunc.params) == arg_count diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 8878e467aad7..760f37505605 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.script import tir as T from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator @@ -96,7 +96,7 @@ def _get_func(): return func func = _get_func() - mod, consts = lower_to_tir(func, cascader=_planner) + mod, consts = _lower_to_tir(func, cascader=_planner) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) reference_mod = WeightStreamOnly @@ -159,7 +159,7 @@ def _get_func(): return func func = _get_func() - mod, consts = lower_to_tir(func, cascader=_cascader) + mod, consts = _lower_to_tir(func, cascader=_cascader) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) reference_mod = RereadWeights @@ -217,7 +217,7 @@ def _get_func(): return func func = _get_func() - mod, consts = lower_to_tir(func) + mod, consts = _lower_to_tir(func) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) @@ -306,7 +306,7 @@ def _get_func(): return func func = _get_func() - mod, consts = 
lower_to_tir(func, cascader=_planner) + mod, consts = _lower_to_tir(func, cascader=_planner) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) @@ -353,7 +353,7 @@ def get_graph(): func = run_opt_pass(func, relay.transform.InferType()) return func - tir_mod, params = lower_to_tir(get_graph(), copy_constants()) + tir_mod, params = _lower_to_tir(get_graph(), copy_constants()) # Check tile address for the scalar constant input hasn't been # overwritten. diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index f82351c28c05..355b7564952e 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -22,7 +22,7 @@ from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_conv2d @@ -69,7 +69,7 @@ def _get_func(): return func func = _get_func() - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) diff --git a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py index 7d4005482a60..b518f513144e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py +++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir import spec -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args @@ 
-71,7 +71,7 @@ def test_binary_elementwise_single( ) func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): @@ -227,7 +227,7 @@ def test_shift_binary_elementwise_single( ) func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 5a9aa9855183..b51c932f2c8e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -21,7 +21,7 @@ from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader from .infra import make_ethosu_conv2d, get_convolutional_args @@ -316,7 +316,7 @@ def _get_func( [(1, 2, 12, 9, 16), 182, 67, (1, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], ] func = _get_func(*trial) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): @@ -593,7 +593,7 @@ def _get_func( reference_mod = trial[0] params = trial[1:] func = _get_func(*params[:-1]) - mod, _ = lower_to_tir(func, cascader=total_cascader(params[-1])) + mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) script = mod.script(show_meta=True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -652,7 +652,7 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): reference_mod = 
trial[0] params = trial[1:] func = _get_func(*params) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) script = mod.script(show_meta=True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -755,7 +755,7 @@ def _get_func(ifm_shape, reshaped, ifm_layout): reference_mod = trial[0] params = trial[1:] func = _get_func(*params) - mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) + mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) script = mod.script(show_meta=True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -775,7 +775,7 @@ def _get_func(): return func func = _get_func() - mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16))) + mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16))) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 4bfbae5f03b7..92b294069a90 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -21,7 +21,7 @@ from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute from .infra import make_ethosu_conv2d @@ -65,7 +65,7 @@ def _get_func(): return func func = _get_func() - mod, _ = lower_to_tir(func, cascader=copy_constants()) + mod, _ = _lower_to_tir(func, cascader=copy_constants()) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) @@ -129,7 +129,7 @@ def _get_func(): return func func = _get_func() - mod, _ = lower_to_tir(func, cascader=_cascader) + mod, _ = 
_lower_to_tir(func, cascader=_cascader) script = mod.script(show_meta=True) test_mod = tvm.script.from_source(script) diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py index edbfb4939b11..fe11a0fb369b 100644 --- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py @@ -22,7 +22,7 @@ import tvm from tvm import relay from tvm.relay.testing import run_opt_pass -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args @@ -108,7 +108,7 @@ def _get_func( return func func = _get_func(*trial) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): diff --git a/tests/python/contrib/test_ethosu/test_replace_identity.py b/tests/python/contrib/test_ethosu/test_replace_identity.py index 1ce55c49ea96..e53230c6eb9a 100644 --- a/tests/python/contrib/test_ethosu/test_replace_identity.py +++ b/tests/python/contrib/test_ethosu/test_replace_identity.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir import spec -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_identity, get_pooling_args @@ -33,7 +33,7 @@ def test_identity(ifm_shape): func = relay.Function(relay.analysis.free_vars(identity), identity) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py index c535498ee04d..0680f0ce9de1 100644 --- 
a/tests/python/contrib/test_ethosu/test_replace_pooling.py +++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir import spec -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_pooling, get_pooling_args @@ -181,7 +181,7 @@ def test_pooling_single( ) func = relay.Function(relay.analysis.free_vars(pooling), pooling) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): @@ -241,7 +241,7 @@ def test_correct_stride_with_multiple_pooling(): ) func = relay.Function(relay.analysis.free_vars(op), op) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py index 498609fb15b7..6240b54261f8 100644 --- a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py +++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir import spec -from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from .infra import make_ethosu_unary_elementwise @@ -69,7 +69,7 @@ def test_unary_elementwise_single( ) func = relay.Function(relay.analysis.free_vars(unary_elementwise), unary_elementwise) func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) + mod, _ = _lower_to_tir(func) data = [] def _visit(stmt): diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py 
b/tests/python/contrib/test_ethosu/test_scheduler.py index 5c6f064873ef..06025910cd09 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -34,7 +34,7 @@ from tvm.relay.backend.contrib.ethosu.tir.compiler import ( lower_to_te, extract_constants, - lower_to_tir, + _lower_to_tir, ) from .infra import ( AttachType, @@ -216,7 +216,7 @@ def test_schedule_diamond_graph(): func = relay.Function(relay.analysis.free_vars(add), add) func = run_opt_pass(func, relay.transform.InferType()) - test_mod, _ = lower_to_tir(func, copy_constants()) + test_mod, _ = _lower_to_tir(func, copy_constants()) reference_mod = DiamondGraphTir tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) From 94fd3cde9a8efd771a84345a4482ef9d76e938da Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 18 Mar 2022 15:01:46 +0000 Subject: [PATCH 4/4] address nit and small fix to example Change-Id: I44c64de15fa8680cc89ce0440ffa6c9e0ec62a50 --- python/tvm/relay/backend/contrib/ethosu/codegen.py | 2 +- python/tvm/relay/backend/contrib/ethosu/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 123a92d96f56..e8b5cc23aff2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -341,7 +341,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: Returns ------- - mod : tvm.ir.TRModule + mod : tvm.ir.IRModule The Relay module with scheduled NPU external functions. 
""" mod = OutlineCompilerFunctions("ethos-u")(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 16b215143d9e..64c561ec7f2c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -335,7 +335,7 @@ def create_npu_function_pass(opt_level: int, name: str = ""): @create_npu_function_pass(opt_level=1) class MyPass: - def transform_npu_function(global_var, func): + def transform_npu_function(self, global_var, func): return func mod = tvm.IRModule()