diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index dce9736adec7..51bdb18d2217 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); /*! - * \brief Auto detect the block read/write region according to body stmt - * It will detect the read/write region as an array in order of appearance in AST + * \brief Auto detect the block access region according to its body stmt + * It will detect the access region as an array in order of appearance in AST * \param block The block to be detected * \param buffer_var_map The outside buffers which may be accessed the block. * It is a map from buffer var to the buffer. @@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain * - second: write regions * - third: opaque regions */ -Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL Array> GetBlockAccessRegion(const Block& block, + const Map& buffer_var_map); + +/*! + * \brief Auto detect the block read/write region according to its body stmt. An opaque access will + * be counted as both a read and a write access + * \param block The block to be detected + * \param buffer_var_map The outside buffers which may be accessed the block. + * It is a map from buffer var to the buffer + * \return An array only consisting of the read regions and write regions of the input block + */ +TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map); /*! * \brief Calculate the expresion complexity based on number of symbols it contains. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 500195ac9a13..d1aaa61c3aae 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -136,7 +136,29 @@ def get_block_access_region( - second: write regions - third: opaque regions """ - return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore + return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore + + +def get_block_read_write_region( + block: Block, buffer_var_map: Dict[Var, Buffer] +) -> List[List[BufferRegion]]: + """Auto detect the block read/write region according to its body stmt. + An opaque access will be counted as both a read and a write access + + Parameters + ---------- + block: tvm.tir.Block + The block in which we are detecting read/write regions. + + buffer_var_map : Dict[Var, Buffer] + The outside buffers which may access the block. Mapping from buffer var to the buffer + + Returns + ------- + result : List[List[BufferRegion]] + An array only consisting of the read regions and write regions of the input block + """ + return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map) # type: ignore def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int: diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index dd01aed61c52..90aaa35d60d8 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -285,7 +285,39 @@ Array> GetBlockAccessRegion(const Block& block, return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; } -TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion); +Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map) { + // Step 1. Get all the read/write/opaque accesses in the input block. + Array> access_regions = GetBlockAccessRegion(block, buffer_var_map); + // Step 2. Collect all the buffers that are opaquely accessed. + std::unordered_set opaque_accessed_buffers; + for (const BufferRegion& opaque_access : access_regions[2]) { + opaque_accessed_buffers.insert(opaque_access->buffer.get()); + } + // Step 3. Create new arrays of read/write regions. + Array new_read_regions; + Array new_write_regions; + new_read_regions.reserve(access_regions[0].size() + access_regions[2].size()); + new_write_regions.reserve(access_regions[1].size() + access_regions[2].size()); + for (const BufferRegion& read_access : access_regions[0]) { + if (!opaque_accessed_buffers.count(read_access->buffer.get())) { + new_read_regions.push_back(read_access); + } + } + for (const BufferRegion& write_access : access_regions[1]) { + if (!opaque_accessed_buffers.count(write_access->buffer.get())) { + new_write_regions.push_back(write_access); + } + } + for (const BufferRegion& opaque_access : access_regions[2]) { + new_read_regions.push_back(opaque_access); + new_write_regions.push_back(opaque_access); + } + return {new_read_regions, new_write_regions}; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 2583b21227e4..9c88cc1e787a 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator { Array reads = std::move(block->reads); Array writes = std::move(block->writes); if (!is_scope_root) { - Array> inspected = GetBlockAccessRegion(block, buffer_var_map_); + Array> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); reads = std::move(inspected[0]); writes = std::move(inspected[1]); } diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 59f9170786b6..97153aedc6a3 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -129,7 +129,10 @@ class BufferAllocationLocator : public StmtExprMutator { /*init=*/NullOpt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - CollectReadWrite(opaque_block, &n->reads, &n->writes); + Array> access = + GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); + n->reads = access[0]; + n->writes = access[1]; BlockRealize realize({}, Bool(true), Block(n)); return std::move(realize); } @@ -144,17 +147,6 @@ class BufferAllocationLocator : public StmtExprMutator { return result; } - void CollectReadWrite(const Block& block, Array* reads, - Array* writes) const { - Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - *reads = access[0]; - *writes = access[1]; - for (const auto& opaque_access : access[2]) { - reads->push_back(opaque_access); - writes->push_back(opaque_access); - } - } - /*! \brief The map from stmt to the buffers to be allocated under it. */ std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 9c95b9819e6f..bc421aa4d19b 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import tir, script from tvm.ir import Range @@ -81,6 +82,20 @@ def opaque_block_func() -> None: B[i, j] = A[i, j] + 1.0 +@tvm.script.tir +def opaque_access_func() -> None: + A = tir.alloc_buffer([1024]) + B = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [v]: + tir.bind(v, i) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([B[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") + ) + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -110,6 +125,19 @@ def test_opaque_block(): tvm.ir.assert_structural_equal(block1.writes, ret[1]) +def test_opaque_access(): + block = opaque_access_func.body.block.body.body.block + alloc_buffers = opaque_access_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block @@ -141,4 +169,5 @@ def test_match_buffer(): if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() + test_opaque_access() test_match_buffer() diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 067952899c0a..bc054938d282 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -17,7 +17,6 @@ # pylint: disable=missing-function-docstring,missing-module-docstring import sys -import numpy as np import pytest import tvm import tvm.testing