From 200352262f7c3d5d20f943c051101d33448ed252 Mon Sep 17 00:00:00 2001
From: Elen Kalda
Date: Sat, 27 Nov 2021 18:40:36 +0000
Subject: [PATCH] [microNPU] Add support for SPLIT and SPLIT_V

Both SPLIT and SPLIT_V get lowered to relay.split, and during legalization
the Relay split is turned into strided slices. This patch adds the pattern
and the legalizer that enable offloading TFLite splits to the NPU.
---
 .../relay/backend/contrib/ethosu/legalize.py |  20 +++
 python/tvm/relay/op/contrib/ethosu.py        |  43 +++++
 .../contrib/test_ethosu/test_codegen.py      |  78 +++++++++
 .../contrib/test_ethosu/test_legalize.py     | 158 ++++++++++++++++++
 4 files changed, 299 insertions(+)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 4677b4469bfa9..1415311f5c5c0 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -109,6 +109,25 @@ def callback(
         return relay.Tuple(strided_slices)
 
 
+class PartitionedSplitRewriter(DFPatternCallback):
+    """This pass brings the split out of the partitioned function"""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.SplitParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        split_input = post.args[0]
+        split_params = ethosu_patterns.SplitParams(post.op.body)
+        indices_or_sections = split_params.indices_or_sections
+        axis = split_params.axis
+        return relay.op.split(split_input, indices_or_sections, axis=axis).astuple()
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeSplit:
     """This is the pass that wraps SplitRewriter"""
@@ -117,6 +136,7 @@ def transform_module(
         self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
     ) -> tvm.ir.IRModule:
         for global_var, func in mod.functions.items():
+            func = rewrite(PartitionedSplitRewriter(), func)
             func = rewrite(SplitRewriter(), func)
             mod.update_func(global_var, func)
         return mod
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 82cc68bb48b3b..2b20beeb1bce8 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -996,6 +996,44 @@ def concat_pattern():
     return optional_clip
 
 
+class SplitParams:
+    """
+    This class will parse a call to an ethos-u.split composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.split"
+
+    def __init__(self, func_body):
+        self.split = func_body
+        self.input = TensorParams(func_body.args[0])
+        self.axis = func_body.attrs.axis
+        self.indices_or_sections = self.convert_indices_or_sections(
+            func_body.attrs.indices_or_sections
+        )
+
+    def convert_indices_or_sections(self, indices_or_sections):
+        # split_v
+        if isinstance(indices_or_sections, tvm.ir.container.Array):
+            values = [i.value for i in indices_or_sections]
+        # split
+        else:
+            values = indices_or_sections.value
+        return values
+
+    def is_valid(self):
+        """Checks whether split has compatible attributes with the hardware"""
+        if not check_valid_dtypes([self.input], supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def split_pattern():
+    "Create the pattern for split"
+    split = is_op("split")(wildcard())
+    return split
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1066,6 +1104,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
         ),
         (TanhParams.composite_name, tanh_pattern(), lambda pat: TanhParams(pat).is_valid()),
         (ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()),
+        (
+            SplitParams.composite_name,
+            split_pattern(),
+            lambda pat: SplitParams(pat).is_valid(),
+        ),
     ]
 
 
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 3429ad7b65cd4..af71f4a8610c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1143,5 +1143,83 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
+# This codegen test checks both split and split_v
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "ifm_shape, num_or_size_splits, axis",
+    [
+        ((1, 4, 6, 8), (1, 3, 4), 3),
+        ((4, 6, 8), 2, 0),
+        ((50,), 25, 0),
+        ((5, 11), 1, 1),
+        ((13,), (13,), 0),
+        ((22, 7), (4, -1), 1),
+    ],
+)
+def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis):
+    dtype = "int8"
+
+    def get_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, num_or_size_splits, axis):
+                op = tf.split(x, num_or_size_splits, axis=axis)
+                return op
+
+        model = Model()
+
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), num_or_size_splits, axis
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = get_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. a single offload module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 9ac515c21eb6c..437da0ed56016 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1139,5 +1139,163 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize(
+    "ifm_shape, num_or_size_splits, axis",
+    [
+        ((1, 4, 6, 8), 3, 2),
+        ((4, 6, 8), 2, 0),
+        ((5, 15), 3, 1),
+        ((3, 7), 1, 1),
+        ((100,), 25, 0),
+    ],
+)
+def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, num_or_size_splits, axis):
+                op = tf.split(x, num_or_size_splits, axis=axis)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        # dig out the split
+        single_output_split = num_or_size_splits == 1
+        split = (
+            ext_func.body.tuple_value
+            if single_output_split
+            else ext_func.body.args[0][0].args[0].tuple_value
+        )
+        assert split.op.name == "split"
+
+        # Split is specified by number of equal chunks
+        assert split.attrs.indices_or_sections == num_or_size_splits
+
+        assert split.attrs.axis == axis
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = ethosu.partition_for_ethosu(mod)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, num_or_size_splits, axis",
+    [
+        ((1, 4, 6, 8), (1, 3, 4), 3),
+        ((10, 18, 4), (1, 4, 3, 2), 0),
+        ((22, 7), (4, -1), 1),
+        ((25,), (25,), 0),
+    ],
+)
+def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, num_or_size_splits, axis):
+                # TF split gets converted into TFLite's split_v
+                op = tf.split(x, num_or_size_splits, axis=axis)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    def verify(ext_func):
+        # dig out the split
+        single_output_split = len(num_or_size_splits) == 1
+        split = (
+            ext_func.body.tuple_value
+            if single_output_split
+            else ext_func.body.args[0][0].args[0].tuple_value
+        )
+        assert split.op.name == "split"
+
+        # Split is specified by the size of each section, so convert num_or_size_splits
+        # into the indices at which the tensor is split, since this is how split is
+        # represented in Relay
+        split_sections = [] if single_output_split else [num_or_size_splits[0]]
+        for split_size in num_or_size_splits[1:-1]:
+            sec = split_sections[-1] + split_size
+            split_sections.append(sec)
+        assert list(split.attrs.indices_or_sections) == split_sections
+
+        assert split.attrs.axis == axis
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = ethosu.partition_for_ethosu(mod)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
+        "tvmgen_default_ethos_u_main_0"
+    ]
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
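Note (reviewer aid, not part of the patch): a minimal sketch of the kind of Relay expression the new legalization operates on, assuming a TVM build with the Relay frontend available. PartitionedSplitRewriter replaces the ethos-u.split composite with a bare relay.split like the one constructed below, which SplitRewriter then decomposes into strided slices. The input shape, section sizes and variable names are made up for illustration.

import tvm
from tvm import relay

# Hypothetical example values, not taken from the patch above.
x = relay.var("x", shape=(1, 4, 6, 8), dtype="int8")
# SPLIT_V-style sections of size 1, 3 and 4 along the last axis correspond to
# cumulative Relay split points indices_or_sections=[1, 4].
split = relay.split(x, indices_or_sections=[1, 4], axis=3)
mod = tvm.IRModule.from_expr(relay.Function([x], split.astuple()))
print(mod)  # shows the single relay.split that the legalizer lowers to strided slices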