diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 0db8db912a51f..ede9cd46371e4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -108,6 +108,25 @@ def callback( return relay.Tuple(strided_slices) +class PartitionedSplitRewriter(DFPatternCallback): + """This pass brings the split out of the partitioned function""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.SplitParams.composite_name}) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + split_input = post.args[0] + split_params = ethosu_patterns.SplitParams(post.op.body) + indices_or_sections = split_params.indices_or_sections + axis = split_params.axis + return relay.op.split(split_input, indices_or_sections, axis=axis).astuple() + + @ir.transform.module_pass(opt_level=1) class LegalizeSplit: """This is the pass that wraps SplitRewriter""" @@ -116,6 +135,7 @@ def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): + func = rewrite(PartitionedSplitRewriter(), func) func = rewrite(SplitRewriter(), func) mod.update_func(global_var, func) return mod diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a7d3da3200b53..73007cffe7268 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1107,6 +1107,44 @@ def concat_pattern(): return concat +class SplitParams: + """ + This class will parse a call to a ethos-u.split composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.split" + + def __init__(self, func_body): + self.split = func_body + self.input = TensorParams(func_body.args[0]) + self.axis = func_body.attrs.axis + self.indices_or_sections = self.convert_indices_or_sections( + func_body.attrs.indices_or_sections + ) + + def convert_indices_or_sections(self, indices_or_sections): + # split_v + if isinstance(indices_or_sections, tvm.ir.container.Array): + values = [i.value for i in indices_or_sections] + # split + else: + values = indices_or_sections.value + return values + + def is_valid(self): + """Checks whether split has compatible attributes with the hardware""" + if not check_valid_dtypes([self.input], supported_dtypes=[np.int8]): + return False + return True + + +def split_pattern(): + "Create the pattern for split" + split = is_op("split")(wildcard()) + return split + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1187,6 +1225,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal sigmoid_pattern(), lambda pat: SigmoidParams(pat).is_valid(), ), + ( + SplitParams.composite_name, + split_pattern(), + lambda pat: SplitParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 0707ec27ca27b..ce2efc7dc3f5a 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -929,5 +929,27 @@ def sigmoid_function(x): _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type) +# This codegen test checks both, split and split_v +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((4, 6, 8), 2, 0), + ((50,), 25, 0), + ((5, 11), 1, 1), + ((13,), (13,), 0), + ((22, 7), (4, -1), 1), + ], +) +def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis): + @tf.function + def split_func(x): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 9dc94d96fb274..9f979153f714a 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -1344,5 +1344,163 @@ def representative_dataset(): assert tuple(func_body.args[1].checked_type.shape) == (256,) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), 3, 2), + ((4, 6, 8), 2, 0), + ((5, 15), 3, 1), + ((3, 7), 1, 1), + ((100,), 25, 0), + ], +) +def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = num_or_size_splits == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by number of equal chunks + assert split.attrs.indices_or_sections == num_or_size_splits + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((10, 18, 4), (1, 4, 3, 2), 0), + ((22, 7), (4, -1), 1), + ((25,), (25,), 0), + ], +) +def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + # TF split gets converted into TFLite's split_v + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = len(num_or_size_splits) == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by the size of sections, so converting num_or_size_splits + # into the indices where the tensor is split at since this is how split is represented + # in Relay + split_sections = [] if single_output_split else [num_or_size_splits[0]] + for split_size in num_or_size_splits[1:-1]: + sec = split_sections[-1] + split_size + split_sections.append(sec) + assert list(split.attrs.indices_or_sections) == split_sections + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__])