Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[microNPU] Fixed MergeConstants pass on striped networks (apache#13281)
Browse files Browse the repository at this point in the history
This PR fixes the bug in MergeConstants pass on striped networks on Ethos-U NPU.

The issue was caused by _DivideConstants_ pass which is introducing new mod parameters and changing their order. So ethosu_write parameter in some cases is moved from the end of the list to the middle.
E.g. from:
`[ethos-u_0_i0, p1, p2, p3, p4, p5, p6, ethosu_write]`
To:
`[ethos-u_0_i0, p1, p2, ethosu_write, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder]`

Updated version of the  _GetArgsToMergeWithoutArgsNotInConstDict_ and _MakeNewConstDict_ methods in passes.cc can now correctly modify const_dict according to the new parameter list.
  • Loading branch information
sergio-grovety authored and xinetzone committed Nov 25, 2022
1 parent 01efb6b commit 0b3bab3
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 29 deletions.
5 changes: 1 addition & 4 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
# MergeConstant pass currently does not support striped schedules.
# It requires further investigation.
if not util.is_striping_enabled():
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)

# When striping is enabled and if storage_rewrite is not run
Expand Down
27 changes: 18 additions & 9 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator {

// Make the new const dict
Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
Array<Array<IntImm>> buffers_to_merge{
Map<IntImm, Array<IntImm>> buffers_to_merge{
GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};

Expand Down Expand Up @@ -832,31 +832,41 @@ class MergeConstantsMutator : public StmtExprMutator {
return vector;
}

Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
Array<Array<IntImm>> new_args_to_merge{};
Map<IntImm, Array<IntImm>> new_args_to_merge{};
bool first_arg_found = false;
int64_t new_arg_key = 0; // the updated key of the merged const_dict
for (Array<IntImm> args : args_to_merge) {
IntImm key{args[0]};
auto it = std::find_if(const_dict.begin(), const_dict.end(),
[&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
return pair.first->value == key->value;
});
if (it != const_dict.end()) {
new_args_to_merge.push_back(args);
if (first_arg_found == false) {
first_arg_found = true;
new_arg_key = key->value;
}
new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
}
if (first_arg_found) {
new_arg_key++;
}
}
return new_args_to_merge;
}

Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> const_dict) {
Map<IntImm, runtime::NDArray> new_const_dict{};
if (args_to_merge.size() == 0) {
return new_const_dict;
}

int64_t key = args_to_merge[0][0]->value;
for (Array<IntImm> args : args_to_merge) {
for (auto const& elem : args_to_merge) {
IntImm key = elem.first;
Array<IntImm> args = elem.second;
int64_t size = 0;
for (IntImm arg : args) {
auto it = std::find_if(const_dict.begin(), const_dict.end(),
Expand All @@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator {
arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
offset += nbytes;
}
new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
key += 1;
new_const_dict.Set(key, constant);
}
return new_const_dict;
}
Expand Down
32 changes: 16 additions & 16 deletions tests/python/contrib/test_ethosu/test_encode_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,15 @@ def _get_func():
@tvm.script.ir_module
class MixedReadU55:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None:
def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([112], "uint8")
buffer3 = T.buffer_decl([112], "uint8")
buffer5 = T.buffer_decl([112], "uint8")
buffer7 = T.buffer_decl([112], "uint8")
buffer9 = T.buffer_decl([592], "uint8")
buffer10 = T.buffer_decl([160], "uint8")
buffer11 = T.buffer_decl([2048], "int8")
# body
p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
p1 = T.buffer_decl([112], "uint8", data=p1_data)
Expand All @@ -357,21 +357,21 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,)
p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
p2 = T.buffer_decl([112], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None


@tvm.script.ir_module
class MixedReadU65:
@T.prim_func
def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None:
def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})

Expand All @@ -381,22 +381,22 @@ def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,)
buffer3 = T.buffer_decl([128], dtype="uint8")
buffer4 = T.buffer_decl([608], dtype="uint8")
buffer5 = T.buffer_decl([160], dtype="uint8")
buffer6 = T.buffer_decl([2048], dtype="int8")
buffer6 = T.buffer_decl([128], dtype="uint8")
p1_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
p1 = T.buffer_decl([128], "uint8", data=p1_data)
p2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
p2 = T.buffer_decl([4096], "int8", data=p2_data)
p3_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
p3 = T.buffer_decl([128], "uint8", data=p3_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 128, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on

Expand Down
Loading

0 comments on commit 0b3bab3

Please sign in to comment.