diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index aaac59ad4a52..4133aff6ef51 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod) mod = ethosu_passes.HoistAllocates()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) - # MergeConstant pass currently does not support striped schedules. - # It requires further investigation. - if not util.is_striping_enabled(): - mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) + mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) mod = ethosu_passes.CopyComputeReordering()(mod) # When striping is enabled and if storage_rewrite is not run diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 2f6fa8f3ea33..d51ffbf833a4 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator { // Make the new const dict Array> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)}; - Array> buffers_to_merge{ + Map> buffers_to_merge{ GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)}; Map new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)}; @@ -832,9 +832,11 @@ class MergeConstantsMutator : public StmtExprMutator { return vector; } - Array> GetArgsToMergeWithoutArgsNotInConstDict( + Map> GetArgsToMergeWithoutArgsNotInConstDict( const Array>& args_to_merge, const Map& const_dict) { - Array> new_args_to_merge{}; + Map> new_args_to_merge{}; + bool first_arg_found = false; + int64_t new_arg_key = 0; // the updated key of the merged const_dict for (Array args : args_to_merge) { IntImm key{args[0]}; auto it = std::find_if(const_dict.begin(), const_dict.end(), @@ -842,21 +844,29 @@ class MergeConstantsMutator : public StmtExprMutator { return pair.first->value == key->value; }); if (it != const_dict.end()) { - new_args_to_merge.push_back(args); + if (first_arg_found == false) { + first_arg_found = true; + new_arg_key = key->value; + } + new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args); + } + if (first_arg_found) { + new_arg_key++; } } return new_args_to_merge; } - Map MakeNewConstDict(const Array>& args_to_merge, + Map MakeNewConstDict(const Map>& args_to_merge, Map const_dict) { Map new_const_dict{}; if (args_to_merge.size() == 0) { return new_const_dict; } - int64_t key = args_to_merge[0][0]->value; - for (Array args : args_to_merge) { + for (auto const& elem : args_to_merge) { + IntImm key = elem.first; + Array args = elem.second; int64_t size = 0; for (IntImm arg : args) { auto it = std::find_if(const_dict.begin(), const_dict.end(), @@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator { arg_constant.CopyToBytes(static_cast(constant->data) + offset, nbytes); offset += nbytes; } - new_const_dict.Set(IntImm(DataType::Int(64), key), constant); - key += 1; + new_const_dict.Set(key, constant); } return new_const_dict; } diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 6ffbf22312ff..c751d44b6156 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -340,15 +340,15 @@ def _get_func(): @tvm.script.ir_module class MixedReadU55: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None: + def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer1 = T.buffer_decl([112], "uint8") buffer3 = T.buffer_decl([112], "uint8") buffer5 = T.buffer_decl([112], "uint8") + buffer7 = T.buffer_decl([112], "uint8") buffer9 = T.buffer_decl([592], "uint8") buffer10 = T.buffer_decl([160], "uint8") - buffer11 = T.buffer_decl([2048], "int8") # body p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([112], "uint8", data=p1_data) @@ -357,21 +357,21 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,) p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) p2 = T.buffer_decl([112], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class MixedReadU65: @T.prim_func - def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None: + def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) @@ -381,7 +381,7 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,) buffer3 = T.buffer_decl([128], dtype="uint8") buffer4 = T.buffer_decl([608], dtype="uint8") buffer5 = T.buffer_decl([160], dtype="uint8") - buffer6 = T.buffer_decl([2048], dtype="int8") + buffer6 = T.buffer_decl([128], dtype="uint8") p1_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) p1 = T.buffer_decl([128], "uint8", data=p1_data) p2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) @@ -389,14 +389,14 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,) p3_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) p3 = T.buffer_decl([128], "uint8", data=p3_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 128, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 337b5c70d125..a5adcfceac83 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -441,6 +441,195 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint check_const_dictionaries(const_dict, new_const_dict) +def test_arbitrary_argument_order(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + p1_data = T.allocate([368], "uint8", "global") + p1 = T.buffer_decl([368], "uint8", data=p1_data) + p2_data = T.allocate([96], "uint8", "global") + p2 = T.buffer_decl([96], "uint8", data=p2_data) + p3_data = T.allocate([368], "uint8", "global") + p3 = T.buffer_decl([368], "uint8", data=p3_data) + p4_data = T.allocate([96], "uint8", "global") + p4 = T.buffer_decl([96], "uint8", data=p4_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1_data = T.allocate([464], "uint8", "global") + p1 = T.buffer_decl([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global") + p2 = T.buffer_decl([464], "uint8", data=p2_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + } + new_const_dict = { + 1: np.concatenate((const_dict[1], const_dict[2])), + 3: np.concatenate((const_dict[4], const_dict[5])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, False) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_arbitrary_argument_order_const_split(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + p1_data = T.allocate([368], "uint8", "global") + p1 = T.buffer_decl([368], "uint8", data=p1_data) + p2_data = T.allocate([96], "uint8", "global") + p2 = T.buffer_decl([96], "uint8", data=p2_data) + p3_data = T.allocate([368], "uint8", "global") + p3 = T.buffer_decl([368], "uint8", data=p3_data) + p4_data = T.allocate([96], "uint8", "global") + p4 = T.buffer_decl([96], "uint8", data=p4_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1_data = T.allocate([464], "uint8", "global") + p1 = T.buffer_decl([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global") + p2 = T.buffer_decl([464], "uint8", data=p2_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 3: np.array([3], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + } + new_const_dict = { + 1: np.concatenate((const_dict[1], const_dict[3])), + 3: np.concatenate((const_dict[4], const_dict[5])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + +def test_arbitrary_argument_order_const_split_mixed(): + # fmt: off + @tvm.script.ir_module + class InputModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + p1_data = T.allocate([368], "uint8", "global") + p1 = T.buffer_decl([368], "uint8", data=p1_data) + p2_data = T.allocate([368], "uint8", "global") + p2 = T.buffer_decl([368], "uint8", data=p2_data) + p3_data = T.allocate([96], "uint8", "global") + p3 = T.buffer_decl([96], "uint8", data=p3_data) + p4_data = T.allocate([96], "uint8", "global") + p4 = T.buffer_decl([96], "uint8", data=p4_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 96, p3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p3[0], 48, p3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + + @tvm.script.ir_module + class ReferenceModule: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # body + p1_data = T.allocate([464], "uint8", "global") + p1 = T.buffer_decl([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global") + p2 = T.buffer_decl([464], "uint8", data=p2_data) + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + # fmt: on + + const_dict = { + 1: np.array([1], dtype=np.uint8), + 2: np.array([2], dtype=np.uint8), + 4: np.array([4], dtype=np.uint8), + 5: np.array([5], dtype=np.uint8), + } + new_const_dict = { + 1: np.concatenate((const_dict[1], const_dict[4])), + 2: np.concatenate((const_dict[2], const_dict[5])), + } + test_mod, const_dict = MergeConstants(const_dict)(InputModule) + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod, reference_mod, True) + check_const_dictionaries(const_dict, new_const_dict) + + def test_cycle_count(): # fmt: off @tvm.script.ir_module