Skip to content

Commit

Permalink
[microNPU] Fix Cascader code generation without StorageRewrite (#13365)
Browse files Browse the repository at this point in the history
There were extra memory allocations for buffers when parts of the buffer for the result were replaced with a buffer for the entire result (in ReplaceOperators pass)
summing up we received a larger size in the number of parts
  • Loading branch information
Aleksei-grovety authored Nov 18, 2022
1 parent 53824d6 commit 37a8855
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
6 changes: 1 addition & 5 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,8 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)

# When striping is enabled and if storage_rewrite is not run
# the striping results in incorrect code generation. This needs
# further investigation. Until such a time that is fixed, disable_storage_rewrite
# user directive will be overridden if striping is enabled.
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
if not disable_storage_rewrite or util.is_striping_enabled():
if not disable_storage_rewrite:
mod = tvm.tir.transform.StorageRewrite()(mod)

mod = tvm.tir.transform.RemoveNoOp()(mod)
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def ReplaceOperators():
producers_consumers = ProducersConsumers()
replace_output_pointer = {}
pointer_to_extents = {}
replaced_pointers = []

ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"])

Expand Down Expand Up @@ -136,9 +137,13 @@ def _replace_operator(stmt):
stmt, producers_consumers
)
if replace_pointer is not None:
# Allocate pointer only once
if replace_pointer in replaced_pointers:
is_allocator = False
replace_output_pointer[output_pointer] = ReplaceInfo(
replace_pointer, is_allocator
)
replaced_pointers.append(replace_pointer)
# Make the extern call
irb = tvm.tir.ir_builder.create()
irb.emit(tvm.tir.call_extern("handle", op_name, *info))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ def tf_graph(x):
@pytest.mark.parametrize(
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
[
("ethos-u55-256", 180288, 15312),
("ethos-u55-128", 180288, 15312),
("ethos-u55-64", 180288, 14544),
("ethos-u55-32", 180272, 14544),
("ethos-u55-256", 180288, 15200),
("ethos-u55-128", 180288, 15200),
("ethos-u55-64", 180288, 14432),
("ethos-u55-32", 180272, 14416),
],
)
def test_depthwise2d_conv2d_pooling(
Expand Down

0 comments on commit 37a8855

Please sign in to comment.