From a66ab1eee6538145307698bd0f8542e49a56cc75 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 29 Aug 2021 00:46:19 +0800 Subject: [PATCH 1/4] [TIR] GetBlockReadWriteRegion --- include/tvm/tir/analysis.h | 19 ++++++++--- python/tvm/tir/analysis/analysis.py | 24 ++++++++++++- .../analysis/block_access_region_detector.cc | 34 ++++++++++++++++++- src/tir/schedule/primitive/compute_inline.cc | 2 +- .../plan_update_buffer_allocation_location.cc | 16 +++------ ...st_tir_analysis_get_block_access_region.py | 31 +++++++++++++++++ .../unittest/test_tir_schedule_reduction.py | 1 - 7 files changed, 107 insertions(+), 20 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index dce9736adec7..e89d87df616c 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); /*! - * \brief Auto detect the block read/write region according to body stmt - * It will detect the read/write region as an array in order of appearance in AST + * \brief Auto detect the block access region according to its body stmt + * It will detect the access region as an array in order of appearance in AST * \param block The block to be detected * \param buffer_var_map The outside buffers which may be accessed the block. * It is a map from buffer var to the buffer. @@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain * - second: write regions * - third: opaque regions */ -Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL Array> GetBlockAccessRegion(const Block& block, + const Map& buffer_var_map); + +/*! + * \brief Auto detect the block read/write region according to its body stmt. 
An opaque access will + * be counted as both a read and a write access + * \param block The block to be detected + * \param buffer_var_map The outside buffers which may be accessed by the block. + * It is a map from buffer var to the buffer + * \return An array only consisting of the read regions and write regions of the input block + */ +TVM_DLL Array> GetBlockReadWriteRegion(Block block, + Map buffer_var_map); /*! * \brief Calculate the expresion complexity based on number of symbols it contains. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 500195ac9a13..d1aaa61c3aae 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -136,7 +136,29 @@ def get_block_access_region( - second: write regions - third: opaque regions """ - return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore + return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore + + +def get_block_read_write_region( + block: Block, buffer_var_map: Dict[Var, Buffer] +) -> List[List[BufferRegion]]: + """Auto detect the block read/write region according to its body stmt. + An opaque access will be counted as both a read and a write access + + Parameters + ---------- + block: tvm.tir.Block + The block in which we are detecting read/write regions. + + buffer_var_map : Dict[Var, Buffer] + The outside buffers which may be accessed by the block. 
Mapping from buffer var to the buffer + + Returns + ------- + result : List[List[BufferRegion]] + An array only consisting of the read regions and write regions of the input block + """ + return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map) # type: ignore def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int: diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index dd01aed61c52..93f3712436e0 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -285,7 +285,39 @@ Array> GetBlockAccessRegion(const Block& block, return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; } -TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion); +Array> GetBlockReadWriteRegion(Block block, Map buffer_var_map) { + // Step 1. Get all the read/write/opaque accesses in the input block. + Array> access_regions = + GetBlockAccessRegion(std::move(block), std::move(buffer_var_map)); + // Step 2. Collect all the buffers that are opaquely accessed. + std::unordered_set opaque_accessed_buffers; + for (const BufferRegion& opaque_access : access_regions[2]) { + opaque_accessed_buffers.insert(opaque_access->buffer.get()); + } + // Step 3. Create new arrays of read/write regions. 
+ Array new_read_regions; + Array new_write_regions; + new_read_regions.reserve(access_regions[0].size() + access_regions[2].size()); + new_write_regions.reserve(access_regions[1].size() + access_regions[2].size()); + for (const BufferRegion& read_access : access_regions[0]) { + if (!opaque_accessed_buffers.count(read_access->buffer.get())) { + new_read_regions.push_back(read_access); + } + } + for (const BufferRegion& write_access : access_regions[1]) { + if (!opaque_accessed_buffers.count(write_access->buffer.get())) { + new_write_regions.push_back(write_access); + } + } + for (const BufferRegion& opaque_access : access_regions[2]) { + new_read_regions.push_back(opaque_access); + new_write_regions.push_back(opaque_access); + } + return {new_read_regions, new_write_regions}; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 2583b21227e4..9c88cc1e787a 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator { Array reads = std::move(block->reads); Array writes = std::move(block->writes); if (!is_scope_root) { - Array> inspected = GetBlockAccessRegion(block, buffer_var_map_); + Array> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); reads = std::move(inspected[0]); writes = std::move(inspected[1]); } diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 59f9170786b6..97153aedc6a3 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -129,7 +129,10 @@ class 
BufferAllocationLocator : public StmtExprMutator { /*init=*/NullOpt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - CollectReadWrite(opaque_block, &n->reads, &n->writes); + Array> access = + GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); + n->reads = access[0]; + n->writes = access[1]; BlockRealize realize({}, Bool(true), Block(n)); return std::move(realize); } @@ -144,17 +147,6 @@ class BufferAllocationLocator : public StmtExprMutator { return result; } - void CollectReadWrite(const Block& block, Array* reads, - Array* writes) const { - Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - *reads = access[0]; - *writes = access[1]; - for (const auto& opaque_access : access[2]) { - reads->push_back(opaque_access); - writes->push_back(opaque_access); - } - } - /*! \brief The map from stmt to the buffers to be allocated under it. */ std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 9c95b9819e6f..802d3eae9dc2 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import pytest import tvm from tvm import tir, script from tvm.ir import Range @@ -81,6 +82,22 @@ def opaque_block_func() -> None: B[i, j] = A[i, j] + 1.0 +@tvm.script.tir +def opaque_access_func() -> None: + A = tir.alloc_buffer([1024]) + B = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [v]: + tir.bind(v, i) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([B[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern( + "test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" + ) + ) + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -110,6 +127,19 @@ def test_opaque_block(): tvm.ir.assert_structural_equal(block1.writes, ret[1]) +def test_opaque_access(): + block = opaque_access_func.body.block.body.body.block + alloc_buffers = opaque_access_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block @@ -141,4 +171,5 @@ def test_match_buffer(): if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() + test_opaque_access() test_match_buffer() diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 067952899c0a..bc054938d282 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -17,7 +17,6 @@ # pylint: disable=missing-function-docstring,missing-module-docstring import sys -import numpy as np import pytest import tvm import 
tvm.testing From 8797022b8fe70ff44a1a64a74503ff5105939168 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 29 Aug 2021 01:16:49 +0800 Subject: [PATCH 2/4] Fix black issue --- .../unittest/test_tir_analysis_get_block_access_region.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 802d3eae9dc2..bc421aa4d19b 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -92,9 +92,7 @@ def opaque_access_func() -> None: tir.reads([A[v * 128 : v * 128 + 128]]) tir.writes([B[v * 128 : v * 128 + 128]]) tir.evaluate( - tir.call_extern( - "test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" - ) + tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") ) From a897f7b526c6cd652f7312d796e00e00a67ad2fa Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 29 Aug 2021 15:14:13 +0800 Subject: [PATCH 3/4] Use constant reference for the interface --- include/tvm/tir/analysis.h | 5 +++-- src/tir/analysis/block_access_region_detector.cc | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index e89d87df616c..b463d29c1499 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -178,8 +178,8 @@ TVM_DLL Array> GetBlockAccessRegion(const Block& block, * It is a map from buffer var to the buffer * \return An array only consisting of the read regions and write regions of the input block */ -TVM_DLL Array> GetBlockReadWriteRegion(Block block, - Map buffer_var_map); +TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map); /*! * \brief Calculate the expresion complexity based on number of symbols it contains. 
@@ -243,3 +243,4 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ + diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 93f3712436e0..99edb549d142 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -285,10 +285,11 @@ Array> GetBlockAccessRegion(const Block& block, return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; } -Array> GetBlockReadWriteRegion(Block block, Map buffer_var_map) { +Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map) { // Step 1. Get all the read/write/opaque accesses in the input block. Array> access_regions = - GetBlockAccessRegion(std::move(block), std::move(buffer_var_map)); + GetBlockAccessRegion(block, buffer_var_map); // Step 2. Collect all the buffers that are opaquely accessed. std::unordered_set opaque_accessed_buffers; for (const BufferRegion& opaque_access : access_regions[2]) { From 1e8943db1de5c7629a3542941419757992430bb9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 29 Aug 2021 15:20:28 +0800 Subject: [PATCH 4/4] Fix lint issue --- include/tvm/tir/analysis.h | 1 - src/tir/analysis/block_access_region_detector.cc | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index b463d29c1499..51bdb18d2217 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -243,4 +243,3 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ - diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 99edb549d142..90aaa35d60d8 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -288,8 +288,7 @@ Array> 
GetBlockAccessRegion(const Block& block, Array> GetBlockReadWriteRegion(const Block& block, const Map& buffer_var_map) { // Step 1. Get all the read/write/opaque accesses in the input block. - Array> access_regions = - GetBlockAccessRegion(block, buffer_var_map); + Array> access_regions = GetBlockAccessRegion(block, buffer_var_map); // Step 2. Collect all the buffers that are opaquely accessed. std::unordered_set opaque_accessed_buffers; for (const BufferRegion& opaque_access : access_regions[2]) {