From ecb3c3feae7c4b7bbba5082585297612d5d9889c Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Thu, 8 Jul 2021 13:44:25 +0100
Subject: [PATCH 1/8] [TIR][USMP] greedy_by_size usmp algo

* Implementation of the greedy-by-size memory planning algorithm
  (a short usage sketch follows below)
* Added a test case with a linear sequence of operators and two pools
* Added a test case with residual structures
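
A minimal usage sketch (abridged from the tests added in this patch;
main_func and tir_mod are assumed to be the run_model PrimFunc and its
IRModule, with pool candidates already attached to the allocates):

    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
    buffer_info_arr = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")(buffer_info_map)
    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)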

Change-Id: I03b41292eab85ddb43710356c23dd123beb24462
---
 src/tir/usmp/algo/greedy_by_size.cc           | 128 ++++++
 .../test_tir_usmp_algo_greedy_by_size.py      | 370 ++++++++++++++++++
 2 files changed, 498 insertions(+)
 create mode 100644 src/tir/usmp/algo/greedy_by_size.cc
 create mode 100644 tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py

diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc
new file mode 100644
index 000000000000..c657ad41a82e
--- /dev/null
+++ b/src/tir/usmp/algo/greedy_by_size.cc
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/algo/greedy_by_size.cc
+ * \brief Implement greedy by size memory planning algorithm
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+namespace algo {
+
+size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
+                                  const int& byte_alignment) {
+  return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
+}
+
+bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
+                      const size_t& size_bytes) {
+  if (candidate_pool->size_hint_bytes == -1) {
+    // this means pool is not bounded
+    return true;
+  }
+  auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
+  auto max_address = next_offset + size_bytes;
+  if (max_address <= pool_size) {
+    return true;
+  }
+  return false;
+}
+
+PoolInfo SelectPlacementPool(
+    const Array<PoolInfo>& pool_candidates,
+    const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
+  for (const auto& pool_info : pool_candidates) {
+    if (pool_offsets.count(pool_info)) {
+      return pool_info;
+    }
+  }
+  ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
+  return PoolInfo();
+}
+
+Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
+  std::vector<BufferInfo> buffer_info_vec;
+  Map<BufferInfo, PoolAllocation> pool_allocations;
+  for (const auto& buffer_info : buffer_info_arr) {
+    buffer_info_vec.push_back(std::move(buffer_info));
+  }
+  std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
+            [](const BufferInfo& a, const BufferInfo& b) {
+              if (a->size_bytes->value == b->size_bytes->value) {
+                if (a->conflicts.size() == b->conflicts.size()) {
+                  auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
+                  auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
+                  return a_name_hash > b_name_hash;
+                } else {
+                  return a->conflicts.size() > b->conflicts.size();
+                }
+              }
+              return a->size_bytes->value > b->size_bytes->value;
+            });
+
+  for (const auto& buf_info : buffer_info_vec) {
+    std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
+    for (const auto& pool_info : buf_info->pool_candidates) {
+      if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
+        pool_offset_candidates[pool_info] = 0;
+      }
+    }
+
+    for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
+      auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
+      size_t next_offset = 0;
+      if (pool_allocations.count(conflict_buf_info)) {
+        auto pool_allocation = pool_allocations[conflict_buf_info];
+        next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
+        next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
+        if (IsValidPlacement(pool_allocation->pool_info, next_offset,
+                             buf_info->size_bytes->value)) {
+          if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
+            pool_offset_candidates[pool_allocation->pool_info] = next_offset;
+          }
+        } else {
+          pool_offset_candidates.erase(pool_allocation->pool_info);
+        }
+      }
+    }
+    auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
+    pool_allocations.Set(
+        buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
+  }
+  return pool_allocations;
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
+      return GreedyBySize(buffer_info_arr);
+    });
+
+}  // namespace algo
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
new file mode 100644
index 000000000000..a24a79ba85ef
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
@@ -0,0 +1,370 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import tir, script
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
+from tvm.tir.usmp import utils as usmp_utils
+from tvm.target import Target
+
+
+def _replace_stmt_with_buf_var_names(buffer_info_map):
+    """helper to replace tir.allocates with buffer names"""
+    new_buffer_info_map = dict()
+    for k, v in buffer_info_map.items():
+        new_buffer_info_map[v.buffer_var.name] = k
+    return new_buffer_info_map
+
+
+def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map):
+    """helper to check expected liveness conflicts"""
+    buf_info = buffer_info_map[main_buf_name]
+    for conflict in buf_info.conflicts:
+        assert conflict.name_hint in conflicting_buf_names
+
+
+def _get_allocates(primfunc):
+    """helper to extract all allocate nodes by name"""
+    allocates = dict()
+
+    def get_allocate(stmt):
+        if isinstance(stmt, tvm.tir.Allocate):
+            allocates[str(stmt.buffer_var.name)] = stmt
+
+    stmt_functor.post_order_visit(primfunc.body, get_allocate)
+    return allocates
+
+
+def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+    """helper to assing poolinfos to allocate nodes in a tir.PrimFunc"""
+
+    def set_poolinfos(stmt):
+        if isinstance(stmt, tvm.tir.Allocate):
+            return tvm.tir.Allocate(
+                buffer_var=stmt.buffer_var,
+                dtype=stmt.dtype,
+                extents=stmt.extents,
+                condition=stmt.condition,
+                body=stmt.body,
+                annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos},
+            )
+
+    return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
+
+
+def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
+    """helper to assing poolinfos to allocate nodes in a IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
+    return ret
+
+
+def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
+    max_workspace_size = 0
+    for buffer_info, pool_allocation in buffer_pool_allocations.items():
+        if pool_allocation.pool_info == pool_info:
+            size_candidate = pool_allocation.byte_offset + buffer_info.size_bytes
+            if size_candidate > max_workspace_size:
+                max_workspace_size = size_candidate
+    assert max_workspace_size == size
+
+
+# fmt: off
+@tvm.script.ir_module
+class LinearStructure:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
+        placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        PaddedInput_7 = T.allocate([157323], "int16", "global")
+        for i0_i1_fused_7 in T.serial(0, 229):
+            for i2_7, i3_7 in T.grid(229, 3):
+                T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+            Conv2dOutput_7 = T.allocate([64], "int32", "global")
+            for ff_3 in T.serial(0, 64):
+                T.store(Conv2dOutput_7, ff_3, 0, True)
+                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+                    T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
+            for ax3_inner_7 in T.serial(0, 64):
+                T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+        placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        tensor_2 = T.allocate([200704], "uint8", "global")
+        for ax0_ax1_fused_4 in T.serial(0, 56):
+            for ax2_4 in T.serial(0, 56):
+                for ax3_init in T.serial(0, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+                for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True)
+        for ax0_ax1_fused_5 in T.serial(0, 56):
+            for ax2_5, ax3_3 in T.grid(56, 64):
+                T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+
+    @T.prim_func
+    def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        sid_9 = T.allocate([301056], "int8", "global")
+        sid_8 = T.allocate([802816], "int8", "global")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_linear():
+    fast_memory_pool = usmp_utils.PoolInfo(
+        pool_name="fast_memory",
+        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        size_hint_bytes=200704,
+    )
+    slow_memory_pool = usmp_utils.PoolInfo(
+        pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}
+    )
+    tir_mod = LinearStructure
+    tir_mod = assign_poolinfos_to_allocates_in_irmodule(
+        tir_mod, [fast_memory_pool, slow_memory_pool]
+    )
+    main_func = tir_mod["tvmgen_default_run_model"]
+    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
+
+    buffer_info_map_names = dict()
+    for buf_info in buffer_info_arr:
+        buffer_info_map_names[buf_info.name_hint] = buf_info
+
+    # check conflicts
+    _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names)
+    _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names)
+    _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names)
+    _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names)
+    _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names)
+
+    _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816)
+    _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704)
+
+
+# fmt: off
+@tvm.script.ir_module
+class ResnetStructure:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
+        placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+        T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
+        # body
+        for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
+            T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
+        placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
+        placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
+        placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
+        T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
+        # body
+        PaddedInput_1 = T.allocate([379456], "int16", "global")
+        for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
+            T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
+            Conv2dOutput_1 = T.allocate([64], "int32", "global")
+            for ff_1 in T.serial(0, 64):
+                T.store(Conv2dOutput_1, ff_1, 0, True)
+                for ry, rx, rc_1 in T.grid(3, 3, 64):
+                    T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
+            for ax3_inner_2 in T.serial(0, 64):
+                T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
+        placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16")
+        placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
+        placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
+        T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
+        # body
+        PaddedInput_2 = T.allocate([360000], "int16", "global")
+        for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
+            T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
+        for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
+            Conv2dOutput_2 = T.allocate([64], "int32", "global")
+            for ax3_outer_1 in T.serial(0, 4):
+                for ff_2 in T.serial(0, 64):
+                    T.store(Conv2dOutput_2, ff_2, 0, True)
+                    for rc_2 in T.serial(0, 64):
+                        T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True)
+                for ax3_inner_3 in T.serial(0, 64):
+                    T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
+        placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16")
+        placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16")
+        placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
+        placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
+        # body
+        PaddedInput_3 = T.allocate([360000], "int16", "global")
+        for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
+            T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
+        for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
+            Conv2dOutput_3 = T.allocate([64], "int32", "global")
+            for ax3_outer_2 in T.serial(0, 4):
+                for ff_3 in T.serial(0, 64):
+                    T.store(Conv2dOutput_3, ff_3, 0, True)
+                    for rc_3 in T.serial(0, 64):
+                        T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True)
+                for ax3_inner_4 in T.serial(0, 64):
+                    T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True)
+
+    @T.prim_func
+    def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        sid_2 = T.allocate([720000], "int8", "global")
+        sid_6 = T.allocate([5760000], "int8", "global")
+        sid_7 = T.allocate([720000], "int8", "global")
+        sid_8 = T.allocate([720000], "int8", "global")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32"))
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
+        placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
+        placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
+        placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
+        T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
+        # body
+        PaddedInput = T.allocate([360000], "int16", "global")
+        for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
+            T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
+        for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
+            Conv2dOutput = T.allocate([64], "int32", "global")
+            for ff in T.serial(0, 64):
+                T.store(Conv2dOutput, ff, 0, True)
+                for rc in T.serial(0, 64):
+                    T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
+            for ax3_inner_1 in T.serial(0, 64):
+                T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_fanout():
+    global_workspace_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    tir_mod = ResnetStructure
+    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+    main_func = tir_mod["tvmgen_default_run_model"]
+    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
+
+    buffer_info_map_names = dict()
+    for buf_info in buffer_info_arr:
+        buffer_info_map_names[buf_info.name_hint] = buf_info
+
+    # check conflicts
+    _verify_conflicts(
+        "sid_6",
+        ["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"],
+        buffer_info_map_names,
+    )
+    _verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names)
+    _verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names)
+    _verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names)
+    _verify_conflicts(
+        "sid_2",
+        [
+            "PaddedInput",
+            "Conv2dOutput",
+            "sid_8",
+            "PaddedInput_1",
+            "Conv2dOutput_1",
+            "sid_7",
+            "PaddedInput_2",
+            "Conv2dOutput_2",
+            "sid_6",
+            "PaddedInput_3",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names)
+    _verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names)
+    _verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names)
+    _verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names)
+    _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names)
+    _verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names)
+    _verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names)
+
+    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256)

From e0ca1d91f14e80c2ad3042e0f642d7fb81fed75a Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Mon, 22 Nov 2021 14:25:41 +0000
Subject: [PATCH 2/8] [TIR][USMP] greedy_by_size usmp algo

* Adding targets to the PrimFuncs in the tests

Change-Id: Ic91947e23cbcc4fc0020eb62f0ed9df26cf919f9
---
 src/tir/usmp/algo/greedy_by_size.cc           |  10 +-
 .../test_tir_usmp_algo_greedy_by_size.py      | 140 ++++++++++++++----
 2 files changed, 119 insertions(+), 31 deletions(-)

diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc
index c657ad41a82e..7aafe8c0974f 100644
--- a/src/tir/usmp/algo/greedy_by_size.cc
+++ b/src/tir/usmp/algo/greedy_by_size.cc
@@ -34,13 +34,13 @@ namespace tir {
 namespace usmp {
 namespace algo {
 
-size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
-                                  const int& byte_alignment) {
+static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
+                                         const int& byte_alignment) {
   return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
 }
 
-bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
-                      const size_t& size_bytes) {
+static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
+                             const size_t& size_bytes) {
   if (candidate_pool->size_hint_bytes == -1) {
     // this means pool is not bounded
     return true;
@@ -53,7 +53,7 @@ bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
   return false;
 }
 
-PoolInfo SelectPlacementPool(
+static PoolInfo SelectPlacementPool(
     const Array<PoolInfo>& pool_candidates,
     const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
   for (const auto& pool_info : pool_candidates) {
diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
index a24a79ba85ef..6bd7832f533b 100644
--- a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
+++ b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
@@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
     return ret
 
 
+def _assign_targets_to_primfuncs_irmodule(mod, target):
+    """helper to assign target for PrimFunc in a IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = basefunc.with_attr("target", target)
+    return ret
+
+
 def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
     max_workspace_size = 0
     for buffer_info, pool_allocation in buffer_pool_allocations.items():
@@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
                 T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
 
     @T.prim_func
-    def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
+    def run_model(input: T.handle, output: T.handle) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
         # body
@@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
 
 
 def test_linear():
+    target = Target("c")
     fast_memory_pool = usmp_utils.PoolInfo(
         pool_name="fast_memory",
-        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
         size_hint_bytes=200704,
     )
     slow_memory_pool = usmp_utils.PoolInfo(
-        pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}
+        pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
     )
     tir_mod = LinearStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = assign_poolinfos_to_allocates_in_irmodule(
         tir_mod, [fast_memory_pool, slow_memory_pool]
     )
-    main_func = tir_mod["tvmgen_default_run_model"]
+    main_func = tir_mod["run_model"]
     buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
@@ -184,13 +195,15 @@ def test_linear():
         buffer_info_map_names[buf_info.name_hint] = buf_info
 
     # check conflicts
-    _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names)
-    _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names)
-    _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names)
+    _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names)
     _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names)
     _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names)
+    _verify_conflicts(
+        "sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names
+    )
+    _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names)
 
-    _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816)
+    _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528)
     _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704)
 
 
@@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
 
 
 def test_fanout():
+    target = Target("c")
     global_workspace_pool = usmp_utils.PoolInfo(
         pool_name="global_workspace",
-        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
     )
     tir_mod = ResnetStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
     main_func = tir_mod["tvmgen_default_run_model"]
     buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
@@ -336,35 +351,108 @@ def test_fanout():
 
     # check conflicts
     _verify_conflicts(
-        "sid_6",
-        ["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"],
+        "Conv2dOutput_1",
+        [
+            "PaddedInput_1",
+            "sid_7",
+        ],
         buffer_info_map_names,
     )
-    _verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names)
-    _verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names)
-    _verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names)
     _verify_conflicts(
-        "sid_2",
+        "sid_8",
         [
             "PaddedInput",
             "Conv2dOutput",
-            "sid_8",
             "PaddedInput_1",
-            "Conv2dOutput_1",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "PaddedInput_2",
+        [
             "sid_7",
+            "sid_6",
+            "Conv2dOutput_2",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "sid_2",
+        [
+            "PaddedInput",
+            "PaddedInput_3",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "Conv2dOutput",
+        [
+            "sid_8",
+            "PaddedInput",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "sid_7",
+        [
+            "Conv2dOutput_1",
+            "PaddedInput_1",
+            "PaddedInput_2",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "sid_6",
+        [
             "PaddedInput_2",
             "Conv2dOutput_2",
+            "Conv2dOutput_3",
+            "PaddedInput_3",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "PaddedInput_3",
+        [
+            "sid_2",
+            "Conv2dOutput_3",
             "sid_6",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "Conv2dOutput_3",
+        [
             "PaddedInput_3",
+            "sid_6",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "PaddedInput",
+        [
+            "sid_2",
+            "sid_8",
+            "Conv2dOutput",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "Conv2dOutput_2",
+        [
+            "sid_6",
+            "PaddedInput_2",
         ],
         buffer_info_map_names,
     )
-    _verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names)
-    _verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names)
-    _verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names)
-    _verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names)
-    _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names)
-    _verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names)
-    _verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names)
-
-    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256)
+    _verify_conflicts(
+        "PaddedInput_1",
+        [
+            "sid_8",
+            "Conv2dOutput_1",
+            "sid_7",
+        ],
+        buffer_info_map_names,
+    )
+
+    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000)

From 26fffb0fbfac1558bef905dee1817ba2e5eeabdf Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Wed, 24 Nov 2021 18:35:09 +0000
Subject: [PATCH 3/8] [TIR][USMP] Greedy algorithms for USMP

This commit implements greedy algorithms for
USMP based on size and number of liveness
conflicts.

* This includes a slight fix to buffer info
  extraction so that buffers of non-linear
  networks owned by the main function do not
  show sporadic liveness.
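
Both algorithms are registered as packed functions. A minimal invocation
sketch (buffer_info_arr is assumed to be the Array of BufferInfo produced
by the tir.usmp.CreateArrayBufferInfo global, as in the tests):

    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    fusmp_algo_greedy_by_conflicts = tvm.get_global_func("tir.usmp.algo.greedy_by_conflicts")
    allocations_by_size = fusmp_algo_greedy_by_size(buffer_info_arr)
    allocations_by_conflicts = fusmp_algo_greedy_by_conflicts(buffer_info_arr)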

Change-Id: I957d543e75b3b0bcf5fc1fbc7870705c875c7a03
---
 src/tir/usmp/algo/greedy.cc                   | 235 +++++++
 src/tir/usmp/algo/greedy_by_size.cc           | 128 ----
 src/tir/usmp/analysis/extract_buffer_info.cc  | 134 ++--
 ...reedy_by_size.py => test_tir_usmp_algo.py} | 116 ++--
 ...st_tir_usmp_analysis_extract_bufferinfo.py | 574 +++++++++++-------
 5 files changed, 746 insertions(+), 441 deletions(-)
 create mode 100644 src/tir/usmp/algo/greedy.cc
 delete mode 100644 src/tir/usmp/algo/greedy_by_size.cc
 rename tests/python/unittest/{test_tir_usmp_algo_greedy_by_size.py => test_tir_usmp_algo.py} (94%)

diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
new file mode 100644
index 000000000000..f0b1581cd616
--- /dev/null
+++ b/src/tir/usmp/algo/greedy.cc
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/algo/greedy.cc
+ * \brief This source contains greedy algorithms for planning
+ * memory for USMP. There are two algorithms present here:
+ * 1) greedy_by_size and 2) greedy_by_conflicts.
+ *
+ * greedy_by_size : this algorithm prioritizes placing the
+ * largest buffers into the given pools. The BufferInfo objects
+ * are sorted by size and placed in each pool while adhering
+ * to the size_hint constraint.
+ *
+ * greedy_by_conflicts : this algorithm prioritizes placing
+ * the most liveness-conflicted buffers into the given pools. The
+ * BufferInfo objects are sorted by the number of conflicts
+ * and placed in each pool while adhering to the size_hint constraint.
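+ *
+ * As an illustrative example (assumed sizes and conflict counts, not taken
+ * from a real network): given buffers A(size=1024, conflicts=2),
+ * B(size=1024, conflicts=3) and C(size=256, conflicts=5), greedy_by_size
+ * visits B, A, C (size first, conflict count breaking ties), whereas
+ * greedy_by_conflicts visits C, B, A (conflict count first, size breaking ties).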
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+namespace algo {
+
+/*!
+ * \brief This is the base class for the greedy algorithms; the sorting
+ * is specialized in the derived classes based on each greedy criterion.
+ */
+class GreedyBase {
+ public:
+  GreedyBase() {}
+  /*!
+   * \brief This function should be implemented by the derived classes to sort the BufferInfo
+   * objects based on their criterion and then call PostSortAllocation.
+   */
+  virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0;
+
+ protected:
+  /*!
+   * \brief Rounds up the offset to satisfy the alignment requirement
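+   * (e.g., round_up_to_byte_alignment(10, 8) returns 16; illustrative values)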
+   */
+  size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
+                                    const int& byte_alignment) {
+    return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
+  }
+
+  /*!
+   * \brief A helper function to check whether an offset is valid given the constraints
+   */
+  bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
+                        const size_t& size_bytes) {
+    if (candidate_pool->size_hint_bytes == -1) {
+      // this means pool is not bounded
+      return true;
+    }
+    auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
+    auto max_address = next_offset + size_bytes;
+    if (max_address <= pool_size) {
+      return true;
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Selects a pool for placement in the given set of ordered pool candidates
+   */
+  PoolInfo SelectPlacementPool(
+      const Array<PoolInfo>& pool_candidates,
+      const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
+    // Here the pool candidates are already ordered by the time they are consumed
+    // by the algorithm. This could be the order the user has specified. However, schedulers are
+    // welcome to change the order for performance reasons.
+    for (const auto& pool_info : pool_candidates) {
+      if (pool_offsets.count(pool_info)) {
+        return pool_info;
+      }
+    }
+    ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
+    return PoolInfo();
+  }
+
+  /*!
+   * \brief This is the base allocation function that works on BufferInfo objects pre-sorted
+   * according to the greedy heuristic. The sorting has to be performed before calling this.
+   */
+  Map<BufferInfo, PoolAllocation> PostSortAllocation(
+      const std::vector<BufferInfo>& buffer_info_vec) {
+    Map<BufferInfo, PoolAllocation> pool_allocations;
+    for (const auto& buf_info : buffer_info_vec) {
+      std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
+      for (const auto& pool_info : buf_info->pool_candidates) {
+        // Mark pool candidates that satisfy the size constraints.
+        if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
+          pool_offset_candidates[pool_info] = 0;
+        }
+      }
+
+      for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
+        auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
+        size_t next_offset = 0;
+        // Only conflicting BufferInfo objects that have already been allocated are considered.
+        if (pool_allocations.count(conflict_buf_info)) {
+          auto pool_allocation = pool_allocations[conflict_buf_info];
+          next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
+          next_offset =
+              round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
+          // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
+          if (IsValidPlacement(pool_allocation->pool_info, next_offset,
+                               buf_info->size_bytes->value)) {
+            // There could be multiple conflicting BufferInfo in the same pool.
+            // Thus, we need to make sure we pick the largest offset of them all.
+            if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
+              pool_offset_candidates[pool_allocation->pool_info] = next_offset;
+            }
+          } else {
+            pool_offset_candidates.erase(pool_allocation->pool_info);
+          }
+        }
+      }
+      auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
+      pool_allocations.Set(
+          buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
+    }
+    return pool_allocations;
+  }
+};
+
+/*!
+ * \brief This class implements the greedy-by-size algorithm, which
+ * sorts BufferInfo objects by size. Please refer to the main
+ * documentation of the file for more details.
+ */
+class GreedySize : public GreedyBase {
+ public:
+  GreedySize() {}
+  Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
+    std::vector<BufferInfo> buffer_info_vec;
+    Map<BufferInfo, PoolAllocation> pool_allocations;
+    for (const auto& buffer_info : buffer_info_arr) {
+      buffer_info_vec.push_back(std::move(buffer_info));
+    }
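+    // Sort buffers in descending order of size; ties are broken by the number
+    // of liveness conflicts and finally by name hash to keep the order deterministic.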
+    std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
+              [](const BufferInfo& a, const BufferInfo& b) {
+                if (a->size_bytes->value == b->size_bytes->value) {
+                  if (a->conflicts.size() == b->conflicts.size()) {
+                    auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
+                    auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
+                    return a_name_hash > b_name_hash;
+                  } else {
+                    return a->conflicts.size() > b->conflicts.size();
+                  }
+                }
+                return a->size_bytes->value > b->size_bytes->value;
+              });
+    return PostSortAllocation(buffer_info_vec);
+  }
+};
+
+/*!
+ * \brief This class implements the greedy-by-conflicts algorithm, which
+ * sorts BufferInfo objects by their number of liveness conflicts. Please
+ * refer to the main documentation of the file for more details.
+ */
+class GreedyConflicts : public GreedyBase {
+ public:
+  GreedyConflicts() {}
+  Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
+    std::vector<BufferInfo> buffer_info_vec;
+    Map<BufferInfo, PoolAllocation> pool_allocations;
+    for (const auto& buffer_info : buffer_info_arr) {
+      buffer_info_vec.push_back(std::move(buffer_info));
+    }
+    std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
+              [](const BufferInfo& a, const BufferInfo& b) {
+                if (a->conflicts.size() == b->conflicts.size()) {
+                  if (a->size_bytes->value == b->size_bytes->value) {
+                    auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
+                    auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
+                    return a_name_hash > b_name_hash;
+                  } else {
+                    return a->size_bytes->value > b->size_bytes->value;
+                  }
+                }
+                return a->conflicts.size() > b->conflicts.size();
+              });
+    return PostSortAllocation(buffer_info_vec);
+  }
+};
+
+Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
+  return GreedySize().PlanMemory(buffer_info_arr);
+}
+
+Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) {
+  return GreedyConflicts().PlanMemory(buffer_info_arr);
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
+      return GreedyBySize(buffer_info_arr);
+    });
+
+TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
+      return GreedyByConflicts(buffer_info_arr);
+    });
+
+}  // namespace algo
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc
deleted file mode 100644
index 7aafe8c0974f..000000000000
--- a/src/tir/usmp/algo/greedy_by_size.cc
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tir/analysis/usmp/algo/greedy_by_size.cc
- * \brief Implement greedy by size memory planning algorithm
- */
-
-#include <tvm/arith/analyzer.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/function.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/usmp/utils.h>
-
-namespace tvm {
-namespace tir {
-namespace usmp {
-namespace algo {
-
-static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
-                                         const int& byte_alignment) {
-  return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
-}
-
-static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
-                             const size_t& size_bytes) {
-  if (candidate_pool->size_hint_bytes == -1) {
-    // this means pool is not bounded
-    return true;
-  }
-  auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
-  auto max_address = next_offset + size_bytes;
-  if (max_address <= pool_size) {
-    return true;
-  }
-  return false;
-}
-
-static PoolInfo SelectPlacementPool(
-    const Array<PoolInfo>& pool_candidates,
-    const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
-  for (const auto& pool_info : pool_candidates) {
-    if (pool_offsets.count(pool_info)) {
-      return pool_info;
-    }
-  }
-  ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
-  return PoolInfo();
-}
-
-Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
-  std::vector<BufferInfo> buffer_info_vec;
-  Map<BufferInfo, PoolAllocation> pool_allocations;
-  for (const auto& buffer_info : buffer_info_arr) {
-    buffer_info_vec.push_back(std::move(buffer_info));
-  }
-  std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
-            [](const BufferInfo& a, const BufferInfo& b) {
-              if (a->size_bytes->value == b->size_bytes->value) {
-                if (a->conflicts.size() == b->conflicts.size()) {
-                  auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
-                  auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
-                  return a_name_hash > b_name_hash;
-                } else {
-                  return a->conflicts.size() > b->conflicts.size();
-                }
-              }
-              return a->size_bytes->value > b->size_bytes->value;
-            });
-
-  for (const auto& buf_info : buffer_info_vec) {
-    std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
-    for (const auto& pool_info : buf_info->pool_candidates) {
-      if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
-        pool_offset_candidates[pool_info] = 0;
-      }
-    }
-
-    for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
-      auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
-      size_t next_offset = 0;
-      if (pool_allocations.count(conflict_buf_info)) {
-        auto pool_allocation = pool_allocations[conflict_buf_info];
-        next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
-        next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
-        if (IsValidPlacement(pool_allocation->pool_info, next_offset,
-                             buf_info->size_bytes->value)) {
-          if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
-            pool_offset_candidates[pool_allocation->pool_info] = next_offset;
-          }
-        } else {
-          pool_offset_candidates.erase(pool_allocation->pool_info);
-        }
-      }
-    }
-    auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
-    pool_allocations.Set(
-        buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
-  }
-  return pool_allocations;
-}
-
-TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
-    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
-      return GreedyBySize(buffer_info_arr);
-    });
-
-}  // namespace algo
-}  // namespace usmp
-}  // namespace tir
-}  // namespace tvm
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index c25578fd9779..609481ffeae9 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -26,6 +26,7 @@
  * conflicts between other tir.allocate nodes.
  */
 #include <tvm/arith/analyzer.h>
+#include <tvm/relay/executor.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/function.h>
@@ -99,11 +100,21 @@ class BufferInfoExtractor : public StmtExprVisitor {
    */
   std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual>
       buffer_info_end_stmt_idx_;
+
+  /*!
+   * \brief This structure contains information regarding an Allocate node.
+   */
+  struct AllocateInfo {
+    tir::Stmt Allocate;
+    PrimFunc prim_func;
+    Call call;
+  };
+
   /*!
    * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure
    * that only one BufferInfo object is created.
    */
-  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
+  std::unordered_map<tir::Var, AllocateInfo, ObjectPtrHash, ObjectPtrEqual> allocate_infos;
   /*!
    * \brief Indicates a count of stmts visited so far to use as a metric of liveness
    */
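
For illustration only (not part of the patch): AllocateInfo keeps, per buffer variable, the Allocate statement together with the PrimFunc that owns it and the Call through which it was last reached, so a repeated call to the same PrimFunc refreshes the call site instead of creating a second BufferInfo. A minimal Python sketch of that bookkeeping (names are illustrative, not TVM API):

    from dataclasses import dataclass
    from typing import Any, Dict

    @dataclass
    class AllocateInfo:
        # Mirrors the C++ struct: the Allocate statement, the PrimFunc that owns it,
        # and the Call through which it was most recently reached.
        allocate: Any
        prim_func: Any
        call: Any

    allocate_infos: Dict[str, AllocateInfo] = {}

    def record(buffer_var: str, allocate: Any, prim_func: Any, call: Any) -> None:
        if buffer_var not in allocate_infos:
            allocate_infos[buffer_var] = AllocateInfo(allocate, prim_func, call)
        else:
            # A repeated call to the same PrimFunc only updates the call site.
            allocate_infos[buffer_var].call = call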
@@ -203,29 +214,41 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
   auto size_bytes = CalculateExtentsSize(op);
   // We only statically memory plan only allocates with known
   // compile time sizes.
-  if (size_bytes.defined() &&
-      allocate_var_to_stmt_map_.find(op->buffer_var) == allocate_var_to_stmt_map_.end()) {
-    // By default, the core compiler is assumed to attach the a default pool to each allocate.
-    ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
-        << "Every statically sized allocate node needs an pool candidate attribute";
-    auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
-
-    // TODO(@manupa-arm): improve the error when the responsible component for attaching a single
-    // pool is added
-    ICHECK(pool_candidates.size() > 0)
-        << "The core compiler should at least attach a single PoolInfo. If there were no "
-           "user-given arguments for memory pools, the default behaviour is a single size "
-           "un-restricted pool is assigned";
-    PrimFunc func = scope_stack_.top().func;
-    Optional<Target> tgt = func->GetAttr<Target>(tvm::attr::kTarget);
-    ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now";
-    auto workspace_alignment =
-        tgt.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
-    auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes,
-                                  pool_candidates, workspace_alignment);
-    auto allocate = GetRef<Allocate>(op);
-    allocate_var_to_stmt_map_.Set(op->buffer_var, allocate);
-    buffer_info_map_.Set(buffer_info, allocate);
+  if (size_bytes.defined()) {
+    if (allocate_infos.find(op->buffer_var) == allocate_infos.end()) {
+      // By default, the core compiler is assumed to attach a default pool to each allocate.
+      ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
+          << "Every statically sized allocate node needs an pool candidate attribute";
+      auto pool_candidates =
+          Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
+
+      // TODO(@manupa-arm): improve the error when the responsible component for attaching a single
+      // pool is added
+      ICHECK(pool_candidates.size() > 0)
+          << "The core compiler should at least attach a single PoolInfo. If there were no "
+             "user-given arguments for memory pools, the default behaviour is a single size "
+             "un-restricted pool is assigned";
+      PrimFunc func = scope_stack_.top().func;
+      Optional<tvm::relay::Executor> executor_config =
+          module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor);
+      Integer workspace_alignment = 16;
+      if (executor_config) {
+        //    ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now";
+        workspace_alignment =
+            executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+      }
+      auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes,
+                                    pool_candidates, workspace_alignment);
+      auto allocate = GetRef<Allocate>(op);
+      allocate_infos[op->buffer_var] =
+          AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call};
+      buffer_info_map_.Set(buffer_info, allocate);
+    } else {
+      // Update the allocate info with the latest call
+      AllocateInfo ai = allocate_infos[op->buffer_var];
+      ai.call = scope_stack_.top().call;
+      allocate_infos[op->buffer_var] = ai;
+    }
   }
 }
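
For illustration only (not part of the patch): the workspace-byte-alignment recorded here, 16 unless the executor config overrides it, is the alignment the planner later rounds candidate byte offsets up to. A minimal Python sketch of that rounding, mirroring the integer formula of the planner's round_up_to_byte_alignment helper:

    def round_up_to_byte_alignment(offset: int, alignment: int) -> int:
        # Round offset up to the next multiple of alignment (alignment > 0).
        return ((offset + alignment - 1) // alignment) * alignment

    # With the default workspace-byte-alignment of 16:
    assert round_up_to_byte_alignment(0, 16) == 0
    assert round_up_to_byte_alignment(10, 16) == 16
    assert round_up_to_byte_alignment(17, 16) == 32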
 
@@ -257,17 +280,25 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) {
     si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_);
   }
   Call current_call = scope_stack_.top().call;
+  PrimFunc current_primfunc = scope_stack_.top().func;
   scope_stack_.push(si);
   StmtExprVisitor::VisitStmt_(op);
   // Extending the liveness to beginning of for-loop next and end of the current for-loop
   for (const Allocate& allocate : scope_stack_.top().allocate_nodes) {
+    AllocateInfo ai = allocate_infos[allocate->buffer_var];
+    Call update_call = current_call;
+    // If the allocate does not belong to the current PrimFunc,
+    // we need to update the call to which the allocate belongs.
+    if (ai.prim_func != current_primfunc) {
+      update_call = ai.call;
+    }
     if (scope_stack_.top().initial_stmt_of_the_nested_loops->value <
-        buffer_info_start_stmt_idx_[current_call][allocate]) {
-      buffer_info_start_stmt_idx_[current_call].Set(
+        buffer_info_start_stmt_idx_[update_call][allocate]) {
+      buffer_info_start_stmt_idx_[update_call].Set(
           allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value);
     }
-    if (current_stmt_idx_ > buffer_info_end_stmt_idx_[current_call][allocate]) {
-      buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+    if (current_stmt_idx_ > buffer_info_end_stmt_idx_[update_call][allocate]) {
+      buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);
     }
   }
   scope_stack_.pop();
@@ -286,12 +317,21 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) {
 void BufferInfoExtractor::VisitExpr_(const VarNode* op) {
   auto var = GetRef<Var>(op);
   Call current_call = scope_stack_.top().call;
-  if (allocate_var_to_stmt_map_.count(var)) {
-    auto allocate = allocate_var_to_stmt_map_[var];
-    if (buffer_info_start_stmt_idx_[current_call].count(allocate) == 0) {
-      buffer_info_start_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+  PrimFunc current_primfunc = scope_stack_.top().func;
+  if (allocate_infos.count(var)) {
+    auto allocate = allocate_infos[var].Allocate;
+    auto allocate_primfunc = allocate_infos[var].prim_func;
+    Call update_call = current_call;
+    if (allocate_primfunc != current_primfunc) {
+      // If the allocate node does not belong to the current PrimFunc,
+      // its access should be reported to the call to the PrimFunc that
+      // the Allocate belongs to.
+      update_call = allocate_infos[var].call;
     }
-    buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_);
+    if (buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) {
+      buffer_info_start_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);
+    }
+    buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);
 
     ScopeInfo& currect_scope_info = scope_stack_.top();
     if (currect_scope_info.for_loop.defined()) {
@@ -320,13 +360,13 @@ void BufferInfoExtractor::UpdateAliases(const Array<PrimExpr>& args, const PrimF
     // to the original allocate
     if (arg->IsInstance<LoadNode>()) {
       auto load = Downcast<Load>(arg);
-      if (allocate_var_to_stmt_map_.count(load->buffer_var)) {
-        allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]);
+      if (allocate_infos.count(load->buffer_var)) {
+        allocate_infos[param_buf] = allocate_infos[load->buffer_var];
       }
     } else if (arg->IsInstance<VarNode>()) {
       auto var = Downcast<Var>(arg);
-      if (allocate_var_to_stmt_map_.count(var)) {
-        allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]);
+      if (allocate_infos.count(var)) {
+        allocate_infos[param_buf] = allocate_infos[var];
       }
     }
   }
@@ -415,18 +455,30 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
 
   // Traverse the liveness events using a open set to track what
   // is live while updating the conflicts through out the linear traversal
-  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
+  //  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
+  std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;
   for (const auto& le_event : le_events_timeline) {
     if (le_event.le_type == START) {
-      for (const auto& open_buffer_info : open_set) {
+      for (const auto& kv : open_set) {
+        BufferInfo open_buffer_info = kv.first;
         open_buffer_info->conflicts.push_back(le_event.buffer_info);
         if (le_event.buffer_info != open_buffer_info) {
           le_event.buffer_info->conflicts.push_back(open_buffer_info);
         }
       }
-      open_set.insert(le_event.buffer_info);
+      //      open_set.insert(le_event.buffer_info);
+      if (open_set.find(le_event.buffer_info) == open_set.end()) {
+        open_set[le_event.buffer_info] = 1;
+      } else {
+        open_set[le_event.buffer_info] += 1;
+      }
     } else {
-      open_set.erase(le_event.buffer_info);
+      if (open_set[le_event.buffer_info] == 1) {
+        open_set.erase(le_event.buffer_info);
+      } else {
+        open_set[le_event.buffer_info] -= 1;
+      }
+      //      open_set.erase(le_event.buffer_info);
     }
   }
   return this->buffer_info_map_;
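
For illustration only (not part of the patch): with multiple calls to the same PrimFunc, the same buffer can now appear in more than one START event before any of its uses end, so the open set of live buffers is reference counted rather than kept as a plain set. A minimal Python sketch, with an illustrative event stream, of why the count matters:

    from collections import Counter

    # "buf" goes live through two call sites before either use ends.
    events = [("START", "buf"), ("START", "buf"), ("END", "buf"), ("END", "buf")]

    open_set = Counter()
    for kind, buf in events:
        if kind == "START":
            open_set[buf] += 1
        else:
            open_set[buf] -= 1
            if open_set[buf] == 0:
                del open_set[buf]
    # A plain set would have dropped "buf" at the first END, ending its
    # liveness (and hence its recorded conflicts) one event too early.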
diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo.py
similarity index 94%
rename from tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
rename to tests/python/unittest/test_tir_usmp_algo.py
index 6bd7832f533b..69d52009fefc 100644
--- a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -167,7 +167,22 @@ def run_model(input: T.handle, output: T.handle) -> None:
 # fmt: on
 
 
-def test_linear():
+def print_conflicts(buffer_info_map):
+    """_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map)"""
+
+    for buffer_info_name, buf_info in buffer_info_map.items():
+        conflict_str = "["
+        for conflict in buf_info.conflicts:
+            conflict_str += f'"{conflict.name_hint}", '
+        conflict_str += "]"
+        print(f'_verify_conflicts("{buffer_info_name}", {conflict_str}, buffer_info_map_names)')
+
+
+@pytest.mark.parametrize(
+    ["algorithm", "fast_memory_size", "slow_memory_size"],
+    [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)],
+)
+def test_linear(algorithm, fast_memory_size, slow_memory_size):
     target = Target("c")
     fast_memory_pool = usmp_utils.PoolInfo(
         pool_name="fast_memory",
@@ -187,7 +202,7 @@ def test_linear():
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
     buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
     buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
 
     buffer_info_map_names = dict()
@@ -203,8 +218,8 @@ def test_linear():
     )
     _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names)
 
-    _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528)
-    _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704)
+    _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, slow_memory_size)
+    _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, fast_memory_size)
 
 
 # fmt: off
@@ -328,7 +343,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
 # fmt: on
 
 
-def test_fanout():
+@pytest.mark.parametrize(
+    ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)]
+)
+def test_fanout(algorithm, workspace_size):
     target = Target("c")
     global_workspace_pool = usmp_utils.PoolInfo(
         pool_name="global_workspace",
@@ -342,7 +360,7 @@ def test_fanout():
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
     buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
     buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
 
     buffer_info_map_names = dict()
@@ -351,36 +369,31 @@ def test_fanout():
 
     # check conflicts
     _verify_conflicts(
-        "Conv2dOutput_1",
+        "sid_7",
         [
             "PaddedInput_1",
-            "sid_7",
-        ],
-        buffer_info_map_names,
-    )
-    _verify_conflicts(
-        "sid_8",
-        [
-            "PaddedInput",
-            "Conv2dOutput",
-            "PaddedInput_1",
+            "sid_2",
+            "Conv2dOutput_1",
+            "PaddedInput_2",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "PaddedInput_2",
+        "Conv2dOutput_3",
         [
-            "sid_7",
+            "PaddedInput_3",
             "sid_6",
-            "Conv2dOutput_2",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "sid_2",
+        "sid_6",
         [
-            "PaddedInput",
+            "Conv2dOutput_2",
+            "PaddedInput_2",
+            "sid_2",
             "PaddedInput_3",
+            "Conv2dOutput_3",
         ],
         buffer_info_map_names,
     )
@@ -388,43 +401,45 @@ def test_fanout():
         "Conv2dOutput",
         [
             "sid_8",
+            "sid_2",
             "PaddedInput",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "sid_7",
+        "PaddedInput_3",
         [
-            "Conv2dOutput_1",
-            "PaddedInput_1",
-            "PaddedInput_2",
+            "sid_6",
+            "sid_2",
+            "Conv2dOutput_3",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "sid_6",
+        "Conv2dOutput_2",
         [
             "PaddedInput_2",
-            "Conv2dOutput_2",
-            "Conv2dOutput_3",
-            "PaddedInput_3",
+            "sid_2",
+            "sid_6",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "PaddedInput_3",
+        "PaddedInput_1",
         [
+            "sid_8",
             "sid_2",
-            "Conv2dOutput_3",
-            "sid_6",
+            "sid_7",
+            "Conv2dOutput_1",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "Conv2dOutput_3",
+        "Conv2dOutput_1",
         [
-            "PaddedInput_3",
-            "sid_6",
+            "sid_7",
+            "PaddedInput_1",
+            "sid_2",
         ],
         buffer_info_map_names,
     )
@@ -438,21 +453,40 @@ def test_fanout():
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "Conv2dOutput_2",
+        "sid_8",
         [
-            "sid_6",
-            "PaddedInput_2",
+            "PaddedInput",
+            "sid_2",
+            "Conv2dOutput",
+            "PaddedInput_1",
         ],
         buffer_info_map_names,
     )
     _verify_conflicts(
-        "PaddedInput_1",
+        "sid_2",
         [
+            "PaddedInput",
             "sid_8",
+            "Conv2dOutput",
+            "PaddedInput_1",
+            "sid_7",
             "Conv2dOutput_1",
+            "PaddedInput_2",
+            "Conv2dOutput_2",
+            "sid_6",
+            "PaddedInput_3",
+        ],
+        buffer_info_map_names,
+    )
+    _verify_conflicts(
+        "PaddedInput_2",
+        [
             "sid_7",
+            "sid_2",
+            "Conv2dOutput_2",
+            "sid_6",
         ],
         buffer_info_map_names,
     )
 
-    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000)
+    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
index fa645f1379ff..abaa0cd2431e 100644
--- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
+++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
@@ -661,149 +661,261 @@ def test_inception_structure():
 
     # check conflicts
     _verify_conflicts(
-        "PaddedInput_8",
+        "sid_3",
+        [
+            "sid_4",
+            "PaddedInput_2",
+            "sid_2",
+            "Conv2dOutput_2",
+            "PaddedInput_1",
+            "Conv2dOutput_1",
+            "sid_20",
+            "PaddedInput_6",
+            "Conv2dOutput_6",
+            "sid_19",
+            "PaddedInput_4",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "Conv2dOutput",
         [
             "sid_6",
-            "Conv2dOutput_8",
-            "sid_5",
+            "PaddedInput",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_26",
+        "Conv2dOutput_7",
+        [
+            "PaddedInput_7",
+            "sid_8",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_4",
         [
+            "sid_5",
+            "sid_3",
+            "PaddedInput_2",
+            "sid_2",
+            "Conv2dOutput_2",
+            "PaddedInput_1",
+            "Conv2dOutput_1",
+            "sid_20",
+            "PaddedInput_6",
+            "Conv2dOutput_6",
+            "sid_19",
             "PaddedInput_4",
             "Conv2dOutput_4",
+            "sid_26",
             "PaddedInput_5",
+            "Conv2dOutput_5",
+            "sid_25",
+            "tensor_3",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput",
+        "sid_2",
         [
-            "sid_6",
-            "PaddedInput",
+            "PaddedInput_2",
+            "sid_3",
+            "sid_4",
+            "Conv2dOutput_2",
+            "PaddedInput_1",
+            "Conv2dOutput_1",
+            "sid_20",
+            "PaddedInput_6",
+            "Conv2dOutput_6",
+            "sid_19",
+            "PaddedInput_4",
+            "Conv2dOutput_4",
+            "sid_26",
+            "PaddedInput_5",
+            "Conv2dOutput_5",
+            "sid_25",
+            "tensor_3",
+            "sid_32",
+            "PaddedInput_3",
+            "Conv2dOutput_3",
+            "sid_31",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_4",
+        "sid_19",
         [
-            "sid_5",
+            "Conv2dOutput_6",
+            "sid_2",
+            "PaddedInput_6",
             "sid_3",
+            "sid_4",
+            "PaddedInput_4",
+            "Conv2dOutput_4",
+            "sid_26",
+            "PaddedInput_5",
+            "Conv2dOutput_5",
+            "sid_25",
             "tensor_3",
+            "sid_32",
+            "PaddedInput_3",
+            "Conv2dOutput_3",
+            "sid_31",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "tensor_2",
+        "PaddedInput_2",
         [
-            "sid_8",
-            "sid_7",
+            "sid_3",
+            "sid_4",
+            "sid_2",
+            "Conv2dOutput_2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_7",
+        "Conv2dOutput_6",
         [
-            "sid_8",
-            "PaddedInput_7",
+            "sid_2",
+            "PaddedInput_6",
+            "sid_3",
+            "sid_4",
+            "sid_19",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_1",
+        "sid_9",
         [
-            "sid_20",
-            "PaddedInput_1",
+            "PaddedInput_7",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_4",
+        "sid_7",
         [
-            "sid_26",
-            "PaddedInput_4",
+            "tensor_2",
+            "PaddedInput",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_2",
+        "PaddedInput_4",
         [
-            "PaddedInput_2",
             "sid_2",
+            "sid_19",
+            "sid_3",
+            "sid_4",
+            "Conv2dOutput_4",
+            "sid_26",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
         "PaddedInput_3",
         [
+            "sid_2",
             "sid_32",
-            "sid_31",
+            "sid_25",
+            "sid_19",
             "Conv2dOutput_3",
+            "sid_31",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_3",
+        "sid_5",
         [
+            "PaddedInput_8",
+            "Conv2dOutput_8",
             "sid_4",
-            "PaddedInput_2",
-            "PaddedInput_1",
-            "PaddedInput_4",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_6",
+        "sid_31",
         [
-            "PaddedInput_6",
+            "Conv2dOutput_3",
+            "PaddedInput_3",
+            "sid_2",
+            "sid_25",
             "sid_19",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_5",
+        "PaddedInput",
         [
-            "PaddedInput_5",
+            "sid_7",
+            "sid_6",
+            "Conv2dOutput",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "Conv2dOutput_2",
+        [
+            "sid_2",
+            "PaddedInput_2",
+            "sid_3",
+            "sid_4",
+        ],
+        buffer_info_map,
+    )
+    _verify_conflicts(
+        "sid_32",
+        [
+            "tensor_3",
+            "sid_2",
             "sid_25",
+            "sid_19",
+            "PaddedInput_3",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_7",
+        "tensor_2",
         [
-            "sid_9",
             "sid_8",
-            "Conv2dOutput_7",
+            "sid_7",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_7",
+        "sid_26",
         [
-            "tensor_2",
-            "PaddedInput",
+            "Conv2dOutput_4",
+            "PaddedInput_4",
+            "sid_2",
+            "sid_19",
+            "sid_4",
+            "PaddedInput_5",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_31",
+        "Conv2dOutput_3",
         [
             "PaddedInput_3",
-            "Conv2dOutput_3",
-            "sid_25",
             "sid_2",
+            "sid_25",
             "sid_19",
+            "sid_31",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_5",
+        "PaddedInput_6",
         [
-            "Conv2dOutput_8",
-            "PaddedInput_8",
+            "sid_2",
+            "sid_3",
+            "sid_20",
             "sid_4",
+            "Conv2dOutput_6",
+            "sid_19",
         ],
         buffer_info_map,
     )
@@ -817,146 +929,132 @@ def test_inception_structure():
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_20",
+        "PaddedInput_8",
         [
-            "PaddedInput_1",
-            "Conv2dOutput_1",
-            "PaddedInput_6",
+            "sid_6",
+            "sid_5",
+            "Conv2dOutput_8",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_8",
+        "Conv2dOutput_5",
         [
-            "PaddedInput_8",
-            "sid_5",
+            "PaddedInput_5",
+            "sid_2",
+            "sid_19",
+            "sid_4",
+            "sid_25",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_1",
+        "Conv2dOutput_1",
         [
+            "PaddedInput_1",
+            "sid_2",
             "sid_3",
+            "sid_4",
             "sid_20",
-            "Conv2dOutput_1",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "Conv2dOutput_3",
+        "tensor_3",
         [
-            "sid_31",
-            "PaddedInput_3",
+            "sid_2",
+            "sid_25",
+            "sid_19",
+            "sid_4",
+            "sid_32",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput",
+        "sid_8",
         [
-            "sid_7",
-            "sid_6",
-            "Conv2dOutput",
+            "Conv2dOutput_7",
+            "PaddedInput_7",
+            "tensor_2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_2",
+        "sid_20",
         [
-            "sid_3",
-            "Conv2dOutput_2",
+            "Conv2dOutput_1",
+            "PaddedInput_1",
             "sid_2",
-        ],
-        buffer_info_map,
-    )
-    _verify_conflicts(
-        "sid_19",
-        [
-            "Conv2dOutput_6",
+            "sid_3",
+            "sid_4",
             "PaddedInput_6",
-            "sid_31",
-            "sid_2",
-            "sid_25",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_4",
+        "Conv2dOutput_8",
         [
-            "sid_3",
-            "sid_26",
-            "Conv2dOutput_4",
+            "sid_5",
+            "PaddedInput_8",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_5",
+        "PaddedInput_1",
         [
-            "sid_26",
-            "Conv2dOutput_5",
-            "sid_25",
+            "sid_2",
+            "sid_3",
+            "sid_4",
+            "Conv2dOutput_1",
+            "sid_20",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "PaddedInput_6",
+        "Conv2dOutput_4",
         [
-            "sid_20",
-            "Conv2dOutput_6",
+            "PaddedInput_4",
+            "sid_2",
             "sid_19",
+            "sid_4",
+            "sid_26",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
         "sid_25",
         [
-            "Conv2dOutput_5",
             "PaddedInput_5",
-            "sid_31",
+            "Conv2dOutput_5",
             "sid_2",
             "sid_19",
-        ],
-        buffer_info_map,
-    )
-    _verify_conflicts(
-        "tensor_3",
-        [
             "sid_4",
-            "sid_32",
-        ],
-        buffer_info_map,
-    )
-    _verify_conflicts(
-        "sid_32",
-        [
             "tensor_3",
+            "sid_32",
             "PaddedInput_3",
+            "Conv2dOutput_3",
+            "sid_31",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_9",
-        [
-            "PaddedInput_7",
-        ],
-        buffer_info_map,
-    )
-    _verify_conflicts(
-        "sid_2",
+        "PaddedInput_7",
         [
-            "Conv2dOutput_2",
-            "PaddedInput_2",
-            "sid_31",
-            "sid_25",
-            "sid_19",
+            "sid_9",
+            "Conv2dOutput_7",
+            "sid_8",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_8",
+        "PaddedInput_5",
         [
-            "PaddedInput_7",
-            "Conv2dOutput_7",
-            "tensor_2",
+            "sid_2",
+            "sid_19",
+            "sid_26",
+            "sid_4",
+            "Conv2dOutput_5",
+            "sid_25",
         ],
         buffer_info_map,
     )
@@ -1276,56 +1374,70 @@ def test_multiple_calls_to_same_primfunc():
 
     # check conflicts
     _verify_conflicts(
-        "sid_18",
+        "sid_6",
         [
-            "sid_19",
-            "sid_2",
-            "T_softmax_exp2",
-            "T_softmax_maxelem2",
-            "T_softmax_expsum2",
-            "T_softmax_norm2",
+            "sid_7",
+            "sid_12",
+            "compute",
+            "compute_global",
+            "sid_11",
+            "sid_10",
+            "T_softmax_exp",
+            "T_softmax_maxelem",
+            "sid_5",
+            "T_softmax_norm",
+            "T_softmax_expsum",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_3",
+        "T_softmax_exp",
         [
-            "data_pad",
-            "conv2d_NCHWc_global",
-            "sid_2",
+            "sid_10",
+            "sid_6",
+            "T_softmax_maxelem",
+            "sid_5",
+            "T_softmax_norm",
+            "T_softmax_expsum",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_norm",
+        "T_softmax_expsum2",
         [
-            "T_softmax_expsum",
-            "T_softmax_exp",
-            "sid_5",
-            "sid_6",
-            "T_softmax_maxelem",
-            "sid_10",
+            "T_softmax_exp2",
+            "T_softmax_norm2",
+            "sid_18",
+            "T_softmax_maxelem2",
+            "sid_2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_norm2",
+        "compute",
         [
-            "T_softmax_expsum2",
-            "T_softmax_maxelem2",
-            "T_softmax_exp2",
-            "sid_18",
+            "sid_12",
+            "sid_6",
+            "compute_global",
+            "sid_11",
+            "sid_19",
+            "sid_20",
             "sid_2",
+            "compute_global",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_11",
+        "compute_global",
         [
             "compute",
             "sid_12",
-            "compute_global",
-            "sid_10",
+            "sid_6",
+            "sid_11",
+            "compute",
+            "sid_19",
+            "sid_20",
+            "sid_2",
         ],
         buffer_info_map,
     )
@@ -1334,75 +1446,93 @@ def test_multiple_calls_to_same_primfunc():
         [
             "sid_11",
             "sid_6",
+            "T_softmax_exp",
+            "T_softmax_maxelem",
             "sid_5",
             "T_softmax_norm",
             "T_softmax_expsum",
-            "T_softmax_maxelem",
-            "T_softmax_exp",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_5",
+        "sid_2",
         [
-            "T_softmax_norm",
-            "T_softmax_expsum",
-            "T_softmax_exp",
-            "sid_6",
-            "T_softmax_maxelem",
-            "sid_10",
-            "sid_4",
+            "sid_3",
+            "sid_5",
             "sid_20",
+            "sid_19",
+            "compute",
+            "compute_global",
+            "sid_18",
+            "T_softmax_norm2",
+            "T_softmax_exp2",
+            "T_softmax_maxelem2",
+            "T_softmax_expsum2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_expsum",
+        "sid_5",
         [
-            "T_softmax_exp",
-            "T_softmax_norm",
-            "sid_5",
-            "sid_6",
             "T_softmax_maxelem",
             "sid_10",
+            "T_softmax_exp",
+            "sid_6",
+            "T_softmax_norm",
+            "T_softmax_expsum",
+            "sid_4",
+            "data_pad",
+            "sid_3",
+            "conv2d_NCHWc_global",
+            "sid_2",
+            "sid_20",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_8",
+        "T_softmax_norm2",
         [
-            "data_pad",
+            "sid_18",
+            "sid_2",
+            "T_softmax_exp2",
+            "T_softmax_maxelem2",
+            "T_softmax_expsum2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_expsum2",
+        "sid_20",
         [
-            "T_softmax_maxelem2",
-            "T_softmax_exp2",
-            "sid_18",
             "sid_2",
-            "T_softmax_norm2",
+            "sid_5",
+            "sid_19",
+            "compute",
+            "compute_global",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_maxelem2",
+        "T_softmax_expsum",
         [
-            "T_softmax_exp2",
-            "sid_18",
-            "sid_2",
-            "T_softmax_expsum2",
-            "T_softmax_norm2",
+            "sid_5",
+            "T_softmax_norm",
+            "T_softmax_maxelem",
+            "sid_10",
+            "T_softmax_exp",
+            "sid_6",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_12",
+        "data_pad",
         [
-            "sid_11",
-            "compute",
-            "compute_global",
+            "sid_8",
+            "conv2d_NCHWc_global",
+            "sid_7",
+            "sid_4",
+            "sid_5",
+            "sid_3",
+            "conv2d_NCHWc_global",
         ],
         buffer_info_map,
     )
@@ -1410,6 +1540,7 @@ def test_multiple_calls_to_same_primfunc():
         "sid_19",
         [
             "sid_20",
+            "sid_2",
             "compute",
             "compute_global",
             "sid_18",
@@ -1423,17 +1554,19 @@ def test_multiple_calls_to_same_primfunc():
             "sid_7",
             "sid_3",
             "data_pad",
+            "sid_5",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_exp2",
+        "sid_18",
         [
-            "sid_18",
+            "sid_19",
             "sid_2",
+            "T_softmax_norm2",
+            "T_softmax_exp2",
             "T_softmax_maxelem2",
             "T_softmax_expsum2",
-            "T_softmax_norm2",
         ],
         buffer_info_map,
     )
@@ -1447,105 +1580,84 @@ def test_multiple_calls_to_same_primfunc():
         buffer_info_map,
     )
     _verify_conflicts(
-        "data_pad",
+        "T_softmax_exp2",
         [
-            "sid_8",
-            "conv2d_NCHWc_global",
-            "sid_7",
-            "sid_4",
-            "sid_3",
-            "conv2d_NCHWc_global",
+            "T_softmax_norm2",
+            "sid_18",
+            "sid_2",
+            "T_softmax_maxelem2",
+            "T_softmax_expsum2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_20",
+        "sid_4",
         [
             "sid_5",
-            "sid_19",
-            "compute",
-            "compute_global",
+            "data_pad",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_4",
+        "T_softmax_maxelem",
         [
+            "sid_10",
+            "T_softmax_exp",
+            "sid_6",
             "sid_5",
-            "data_pad",
+            "T_softmax_norm",
+            "T_softmax_expsum",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "T_softmax_exp",
+        "T_softmax_maxelem2",
         [
-            "T_softmax_expsum",
-            "T_softmax_norm",
-            "sid_5",
-            "sid_6",
-            "T_softmax_maxelem",
-            "sid_10",
+            "T_softmax_exp2",
+            "T_softmax_norm2",
+            "sid_18",
+            "sid_2",
+            "T_softmax_expsum2",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "compute_global",
+        "sid_11",
         [
-            "sid_12",
-            "sid_11",
-            "compute",
             "compute",
-            "sid_20",
-            "sid_19",
+            "sid_12",
+            "compute_global",
+            "sid_6",
+            "sid_10",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "compute",
+        "sid_12",
         [
-            "sid_11",
-            "sid_12",
-            "compute_global",
-            "sid_20",
-            "sid_19",
+            "sid_6",
+            "compute",
             "compute_global",
+            "sid_11",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_6",
+        "T_softmax_norm",
         [
-            "sid_7",
             "sid_5",
-            "T_softmax_norm",
-            "T_softmax_expsum",
-            "T_softmax_exp",
             "T_softmax_maxelem",
             "sid_10",
-        ],
-        buffer_info_map,
-    )
-    _verify_conflicts(
-        "T_softmax_maxelem",
-        [
+            "T_softmax_exp",
             "sid_6",
-            "sid_5",
-            "T_softmax_norm",
             "T_softmax_expsum",
-            "T_softmax_exp",
-            "sid_10",
         ],
         buffer_info_map,
     )
     _verify_conflicts(
-        "sid_2",
+        "sid_8",
         [
-            "sid_3",
-            "sid_18",
-            "T_softmax_exp2",
-            "T_softmax_maxelem2",
-            "T_softmax_expsum2",
-            "T_softmax_norm2",
+            "data_pad",
         ],
         buffer_info_map,
     )

From 1d5783214bce107d4f4abfd52483463fe6bc5c98 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Mon, 29 Nov 2021 19:10:26 +0000
Subject: [PATCH 4/8] Removing unimplemented Python API for USMP

There were some remnants of unimplemented Python
APIs related to BufferInfo. They are removed now.

Change-Id: I4de1d817eb34187bc20da2ac2b1cb0da5b372833
---
 python/tvm/tir/usmp/utils.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py
index 0445775869e8..188d4c57810b 100644
--- a/python/tvm/tir/usmp/utils.py
+++ b/python/tvm/tir/usmp/utils.py
@@ -114,18 +114,6 @@ def __init__(
             alignment,
         )
 
-    def set_pool_candidates(self, pool_candidates: list):
-        """Sets the pool candidate names"""
-        _ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates)
-
-    def set_pool_offsets(self, pool_name: str, pool_offset: int):
-        """Sets the pool offset by name"""
-        _ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset)
-
-    def set_conflicts(self, conflicts: list):
-        """Sets the the conflicting array of buffer info objects"""
-        _ffi_api.BufferInfoSetConflicts(self, conflicts)
-
 
 @register_object("tir.usmp.PoolAllocation")
 class PoolAllocation(Object):

From 78e099b6e35cb54cbf9eafd97096e136689d1a43 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Mon, 29 Nov 2021 19:15:32 +0000
Subject: [PATCH 5/8] Removing commented code from extract buffer info pass

Change-Id: Ia2d24753398bb388918aecba3b0191100d5100a6
---
 src/tir/usmp/analysis/extract_buffer_info.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index 609481ffeae9..f82db06685f4 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -233,7 +233,6 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
           module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor);
       Integer workspace_alignment = 16;
       if (executor_config) {
-        //    ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now";
         workspace_alignment =
             executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
       }

From 7c93c60c9a383ff74934eff63ae79a9d02fe8357 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Tue, 30 Nov 2021 19:55:31 +0000
Subject: [PATCH 6/8] [TIR][USMP] Greedy algorithms for USMP

This commit removes commented-out lines,
makes a few trivial cleanups, and adds a few
BufferInfo-based tests to check the algorithm.

Change-Id: I1a12b6a424370e9e4c4a55563dde0ad698b07ea3
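
For illustration only (not part of the patch): the BufferInfo-based tests added below drive the registered algorithms directly, without building a full TIR module. Their general shape, with illustrative names and sizes (the import path for the usmp_utils helpers is assumed):

    import tvm
    from tvm.target import Target

    # Assumed import path; the tests below refer to these helpers as `usmp_utils`.
    import tvm.tir.usmp.utils as usmp_utils

    target = Target("c")
    pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    bi_x = usmp_utils.BufferInfo(name_hint="bi_x", size_bytes=10, pool_candidates=[pool])
    bi_y = usmp_utils.BufferInfo(name_hint="bi_y", size_bytes=20, pool_candidates=[pool])
    bi_x.set_conflicts([bi_y])
    bi_y.set_conflicts([bi_x])

    fusmp_algo = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    # Returns a map of BufferInfo -> PoolAllocation (pool_info, byte_offset).
    buffer_pool_allocations = fusmp_algo([bi_x, bi_y])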
---
 python/tvm/tir/usmp/utils.py                 |   4 +
 src/tir/usmp/algo/greedy.cc                  |  10 +-
 src/tir/usmp/analysis/extract_buffer_info.cc |   3 -
 tests/python/unittest/test_tir_usmp_algo.py  | 232 +++++++++++++++++--
 4 files changed, 217 insertions(+), 32 deletions(-)

diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py
index 188d4c57810b..470765174acb 100644
--- a/python/tvm/tir/usmp/utils.py
+++ b/python/tvm/tir/usmp/utils.py
@@ -114,6 +114,10 @@ def __init__(
             alignment,
         )
 
+    def set_conflicts(self, conflicts: list):
+        """Sets the the conflicting array of buffer info objects"""
+        _ffi_api.BufferInfoSetConflicts(self, conflicts)
+
 
 @register_object("tir.usmp.PoolAllocation")
 class PoolAllocation(Object):
diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index f0b1581cd616..78afe0333c99 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file tir/analysis/usmp/algo/greedy_by_size.cc
+ * \file tir/analysis/usmp/algo/greedy.cc
  * \brief This source contains greedy algorithms for planning
  * memory for USMP. There are two algorithms present here :
  * 1) greedy_by_size and 2) greedy_by_conflicts.
@@ -89,17 +89,17 @@ class GreedyBase {
    * \brief Selects a pool for placement in the given set of ordered pool candidates
    */
   PoolInfo SelectPlacementPool(
-      const Array<PoolInfo>& pool_candidates,
+      const BufferInfo& buf_info,
       const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
     // Here the pool candidates are ordered when it is consumed by the algorithm.
     // This could be from order the user has specified. However, schedulers are
     // welcome to change the order for performance reasons.
-    for (const auto& pool_info : pool_candidates) {
+    for (const auto& pool_info : buf_info->pool_candidates) {
       if (pool_offsets.count(pool_info)) {
         return pool_info;
       }
     }
-    ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
+    CHECK(false) << "TVM USMP Error: no candidate have been selected for " << buf_info;
     return PoolInfo();
   }
 
@@ -141,7 +141,7 @@ class GreedyBase {
           }
         }
       }
-      auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
+      auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
       pool_allocations.Set(
           buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
     }
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index f82db06685f4..3fea7211672f 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -454,7 +454,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
 
   // Traverse the liveness events using a open set to track what
   // is live while updating the conflicts through out the linear traversal
-  //  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
   std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;
   for (const auto& le_event : le_events_timeline) {
     if (le_event.le_type == START) {
@@ -465,7 +464,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
           le_event.buffer_info->conflicts.push_back(open_buffer_info);
         }
       }
-      //      open_set.insert(le_event.buffer_info);
       if (open_set.find(le_event.buffer_info) == open_set.end()) {
         open_set[le_event.buffer_info] = 1;
       } else {
@@ -477,7 +475,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
       } else {
         open_set[le_event.buffer_info] -= 1;
       }
-      //      open_set.erase(le_event.buffer_info);
     }
   }
   return this->buffer_info_map_;
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 69d52009fefc..4efb88affd16 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -51,7 +51,7 @@ def get_allocate(stmt):
     return allocates
 
 
-def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
     """helper to assing poolinfos to allocate nodes in a tir.PrimFunc"""
 
     def set_poolinfos(stmt):
@@ -68,12 +68,12 @@ def set_poolinfos(stmt):
     return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
 
 
-def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
+def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
     """helper to assing poolinfos to allocate nodes in a IRModule"""
     ret = tvm.IRModule()
     for global_var, basefunc in mod.functions.items():
         if isinstance(basefunc, tvm.tir.PrimFunc):
-            ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
+            ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
     return ret
 
 
@@ -96,9 +96,204 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
     assert max_workspace_size == size
 
 
+def test_no_pool_error():
+    target = Target("c")
+    tiny_workspace_pool = usmp_utils.PoolInfo(
+        pool_name="tiny_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        size_hint_bytes=10,
+    )
+    bi_a = usmp_utils.BufferInfo(
+        name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool]
+    )
+    bi_b = usmp_utils.BufferInfo(
+        name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool]
+    )
+    bi_c = usmp_utils.BufferInfo(
+        name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool]
+    )
+    bi_a.set_conflicts([bi_b])
+    bi_b.set_conflicts([bi_c])
+    bi_c.set_conflicts([bi_a])
+    buffer_info_arr = [bi_a, bi_b, bi_c]
+    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size")
+    with pytest.raises(
+        tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode"
+    ):
+        buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+
+
+@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
+def test_name_based_ordering(algorithm):
+    """ This checks when the size and conlicts are same a stable result is generated"""
+
+    def _test():
+        target = Target("c")
+        global_workspace_pool = usmp_utils.PoolInfo(
+            pool_name="global_workspace",
+            target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        )
+        bi_a = usmp_utils.BufferInfo(
+            name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
+        )
+        bi_b = usmp_utils.BufferInfo(
+            name_hint="bi_b", size_bytes=10, pool_candidates=[global_workspace_pool]
+        )
+        bi_c = usmp_utils.BufferInfo(
+            name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool]
+        )
+        bi_a.set_conflicts([bi_b])
+        bi_b.set_conflicts([bi_c])
+        bi_c.set_conflicts([bi_a])
+
+        buffer_info_arr = [bi_a, bi_b, bi_c]
+        fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
+        buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+        assert buffer_pool_allocations[bi_a].byte_offset == 0
+        assert buffer_pool_allocations[bi_b].byte_offset == 20
+        assert buffer_pool_allocations[bi_c].byte_offset == 10
+
+    # This is tested several times to check stability
+    for x in range(0, 10):
+        _test()
+
+
+@pytest.mark.parametrize(
+    ["algorithm", "workspace_size"],
+    [("greedy_by_size", 140), ("greedy_by_conflicts", 140)],
+)
+def test_linear(algorithm, workspace_size):
+    """
+    The test case here represents BufferInfo objects
+    that could get generated for a linear sequence
+    such as:
+    (Op A)
+    |
+    bi_a
+    |
+    (Op B)
+    |
+    bi_b
+    |
+    .
+    .
+    .
+    (Op F)
+    |
+    bi_f
+    """
+    target = Target("c")
+    global_workspace_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    bi_a = usmp_utils.BufferInfo(
+        name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
+    )
+    bi_b = usmp_utils.BufferInfo(
+        name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool]
+    )
+    bi_c = usmp_utils.BufferInfo(
+        name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool]
+    )
+    bi_d = usmp_utils.BufferInfo(
+        name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool]
+    )
+    bi_e = usmp_utils.BufferInfo(
+        name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool]
+    )
+    bi_f = usmp_utils.BufferInfo(
+        name_hint="bi_f", size_bytes=50, pool_candidates=[global_workspace_pool]
+    )
+
+    # Creating conflicts for a linear graph
+    bi_a.set_conflicts([bi_b])
+    bi_b.set_conflicts([bi_a, bi_c])
+    bi_c.set_conflicts([bi_b, bi_d])
+    bi_d.set_conflicts([bi_c, bi_e])
+    bi_e.set_conflicts([bi_d, bi_f])
+    bi_f.set_conflicts([bi_e])
+
+    buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f]
+    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
+
+
+@pytest.mark.parametrize(
+    ["algorithm", "workspace_size"],
+    [("greedy_by_size", 190), ("greedy_by_conflicts", 320)],
+)
+def test_fanout(algorithm, workspace_size):
+    """
+    The test case here represents BufferInfo objects
+    that could get generated for a fanout topology
+    such as:
+    (Op A)
+    |
+    bi_a ---------
+    |            |
+    (Op B)     (Op C)
+    |            |
+    bi_b        bi_c
+    |            |
+    (Op D)     (Op E)
+    |            |
+    bi_d        bi_e
+    |            |
+    (Op F) ------
+    |
+    bi_f
+    |
+    (Op G)
+    |
+    bi_g
+    """
+    target = Target("c")
+    global_workspace_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    bi_a = usmp_utils.BufferInfo(
+        name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
+    )
+    bi_b = usmp_utils.BufferInfo(
+        name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool]
+    )
+    bi_c = usmp_utils.BufferInfo(
+        name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool]
+    )
+    bi_d = usmp_utils.BufferInfo(
+        name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool]
+    )
+    bi_e = usmp_utils.BufferInfo(
+        name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool]
+    )
+    bi_f = usmp_utils.BufferInfo(
+        name_hint="bi_f", size_bytes=60, pool_candidates=[global_workspace_pool]
+    )
+    bi_g = usmp_utils.BufferInfo(
+        name_hint="bi_g", size_bytes=70, pool_candidates=[global_workspace_pool]
+    )
+
+    # Creating conflicts for a fanout graph
+    bi_a.set_conflicts([bi_b, bi_c])
+    bi_b.set_conflicts([bi_a, bi_c, bi_e])
+    bi_c.set_conflicts([bi_e, bi_a, bi_b, bi_d])
+    bi_d.set_conflicts([bi_b, bi_f, bi_c, bi_e])
+    bi_e.set_conflicts([bi_c, bi_f, bi_b, bi_d])
+    bi_f.set_conflicts([bi_d, bi_e, bi_f])
+    bi_g.set_conflicts([bi_f])
+
+    buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g]
+    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
+
+
 # fmt: off
 @tvm.script.ir_module
-class LinearStructure:
+class MobilenetStructure:
     @T.prim_func
     def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
         # function attr dict
@@ -167,22 +362,11 @@ def run_model(input: T.handle, output: T.handle) -> None:
 # fmt: on
 
 
-def print_conflicts(buffer_info_map):
-    """_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map)"""
-
-    for buffer_info_name, buf_info in buffer_info_map.items():
-        conflict_str = "["
-        for conflict in buf_info.conflicts:
-            conflict_str += f'"{conflict.name_hint}", '
-        conflict_str += "]"
-        print(f'_verify_conflicts("{buffer_info_name}", {conflict_str}, buffer_info_map_names)')
-
-
 @pytest.mark.parametrize(
     ["algorithm", "fast_memory_size", "slow_memory_size"],
     [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)],
 )
-def test_linear(algorithm, fast_memory_size, slow_memory_size):
+def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
     target = Target("c")
     fast_memory_pool = usmp_utils.PoolInfo(
         pool_name="fast_memory",
@@ -192,9 +376,9 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size):
     slow_memory_pool = usmp_utils.PoolInfo(
         pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
     )
-    tir_mod = LinearStructure
+    tir_mod = MobilenetStructure
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
-    tir_mod = assign_poolinfos_to_allocates_in_irmodule(
+    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
         tir_mod, [fast_memory_pool, slow_memory_pool]
     )
     main_func = tir_mod["run_model"]
@@ -202,8 +386,8 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size):
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
     buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
+    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
 
     buffer_info_map_names = dict()
     for buf_info in buffer_info_arr:
@@ -346,7 +530,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
 @pytest.mark.parametrize(
     ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)]
 )
-def test_fanout(algorithm, workspace_size):
+def test_resnet_subgraph(algorithm, workspace_size):
     target = Target("c")
     global_workspace_pool = usmp_utils.PoolInfo(
         pool_name="global_workspace",
@@ -354,14 +538,14 @@ def test_fanout(algorithm, workspace_size):
     )
     tir_mod = ResnetStructure
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
-    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
     main_func = tir_mod["tvmgen_default_run_model"]
     buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
     buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
+    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
 
     buffer_info_map_names = dict()
     for buf_info in buffer_info_arr:

From 6101e61017619ee52f9d582de101a491d7172d5e Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Wed, 1 Dec 2021 01:54:06 +0000
Subject: [PATCH 7/8] [TIR][USMP] Greedy algorithms for USMP

Changed the sorting criteria to use
alphabetic ordering instead of string
hashes, as the hash values appeared to
differ across platforms.

Change-Id: Ia7938d1b0d1374924c3ec7287526ccf374c54eb7
---
 src/tir/usmp/algo/greedy.cc                 | 8 ++------
 tests/python/unittest/test_tir_usmp_algo.py | 6 +++---
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index 78afe0333c99..8d65e0a10ef0 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -167,9 +167,7 @@ class GreedySize : public GreedyBase {
               [](const BufferInfo& a, const BufferInfo& b) {
                 if (a->size_bytes->value == b->size_bytes->value) {
                   if (a->conflicts.size() == b->conflicts.size()) {
-                    auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
-                    auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
-                    return a_name_hash > b_name_hash;
+                    return std::string(a->name_hint->data) > std::string(b->name_hint->data);
                   } else {
                     return a->conflicts.size() > b->conflicts.size();
                   }
@@ -198,9 +196,7 @@ class GreedyConflicts : public GreedyBase {
               [](const BufferInfo& a, const BufferInfo& b) {
                 if (a->conflicts.size() == b->conflicts.size()) {
                   if (a->size_bytes->value == b->size_bytes->value) {
-                    auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
-                    auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
-                    return a_name_hash > b_name_hash;
+                    return std::string(a->name_hint->data) > std::string(b->name_hint->data);
                   } else {
                     return a->size_bytes->value > b->size_bytes->value;
                   }
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 4efb88affd16..32cd30fb5bed 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -149,9 +149,9 @@ def _test():
         buffer_info_arr = [bi_a, bi_b, bi_c]
         fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
         buffer_pool_allocations = fusmp_algo(buffer_info_arr)
-        assert buffer_pool_allocations[bi_a].byte_offset == 0
-        assert buffer_pool_allocations[bi_b].byte_offset == 20
-        assert buffer_pool_allocations[bi_c].byte_offset == 10
+        assert buffer_pool_allocations[bi_a].byte_offset == 20
+        assert buffer_pool_allocations[bi_b].byte_offset == 10
+        assert buffer_pool_allocations[bi_c].byte_offset == 0
 
     # This is tested for several times to check stability
     for x in range(0, 10):
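
The effect of the new tie-break can be illustrated in isolation. The sketch
below is not TVM code: it uses a simplified stand-in struct, and it assumes
(as the updated offsets above suggest) three mutually conflicting buffers of
10 bytes each with equal conflict counts. Sorting by size, then conflict
count, then descending name yields the same order on every platform, which is
what produces the 0/10/20 byte offsets asserted in the updated test.

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified stand-in for TVM's BufferInfo; the struct and field names
    // here are illustrative only.
    struct SimpleBufferInfo {
      std::string name;
      int size_bytes;
      size_t num_conflicts;
    };

    int main() {
      // Three equally sized, equally conflicting buffers (sizes assumed).
      std::vector<SimpleBufferInfo> buffers = {
          {"bi_a", 10, 2}, {"bi_b", 10, 2}, {"bi_c", 10, 2}};

      // Same ordering as the patched comparator: size first, then conflict
      // count, then a lexicographic comparison of the names. Unlike
      // std::hash<std::string>, whose values may differ between standard
      // library implementations, this ordering is identical everywhere.
      std::sort(buffers.begin(), buffers.end(),
                [](const SimpleBufferInfo& a, const SimpleBufferInfo& b) {
                  if (a.size_bytes == b.size_bytes) {
                    if (a.num_conflicts == b.num_conflicts) {
                      return a.name > b.name;
                    }
                    return a.num_conflicts > b.num_conflicts;
                  }
                  return a.size_bytes > b.size_bytes;
                });

      // Prints bi_c, bi_b, bi_a: placing 10-byte buffers greedily in this
      // order gives byte offsets 0, 10 and 20 respectively.
      for (const auto& buf : buffers) {
        std::cout << buf.name << "\n";
      }
      return 0;
    }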

From 87b0ac9cda13170fd634c09492a0c3e3fbf0f35e Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne@arm.com>
Date: Wed, 1 Dec 2021 15:08:39 +0000
Subject: [PATCH 8/8] [TIR][USMP] Greedy algorithms for USMP

Improving the error message

Change-Id: Ib59efb172fe10b70f88a24f4874a7891e8a9cde7
---
 src/tir/usmp/algo/greedy.cc                 | 4 +++-
 tests/python/unittest/test_tir_usmp_algo.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index 8d65e0a10ef0..b98d828ae745 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -99,7 +99,9 @@ class GreedyBase {
         return pool_info;
       }
     }
-    CHECK(false) << "TVM USMP Error: no candidate have been selected for " << buf_info;
+    CHECK(false) << "TVM USMP Error: the space available in the provided pools has been exceeded "
+                    "when trying to allocate the buffer: "
+                 << buf_info << ".\nPlease increase the size_hints for the memory pools.";
     return PoolInfo();
   }
 
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 32cd30fb5bed..61a70ad062d1 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -118,7 +118,7 @@ def test_no_pool_error():
     buffer_info_arr = [bi_a, bi_b, bi_c]
     fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size")
     with pytest.raises(
-        tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode"
+        tvm.TVMError, match="TVM USMP Error: the space available in the provided pools has been"
     ):
         buffer_pool_allocations = fusmp_algo(buffer_info_arr)