From 5e76a6d9291d3046953ecdccbcf2bac7ebc9b69f Mon Sep 17 00:00:00 2001 From: "fengrong.jia" Date: Mon, 25 Jul 2022 13:58:22 +0800 Subject: [PATCH] [TIR Pass] decouple flatten buffer to lower opaque block pass and flatten buffer. --- include/tvm/tir/transform.h | 11 +- python/tvm/script/tir/scope_handler.py | 10 +- python/tvm/tir/transform/transform.py | 16 +- src/driver/driver_api.cc | 1 + src/meta_schedule/postproc/verify_gpu_code.cc | 1 + src/tir/transforms/flatten_buffer.cc | 135 +------ src/tir/transforms/lower_opaque_block.cc | 177 ++++++++++ tests/python/unittest/test_tir_buffer.py | 1 + .../test_tir_transform_flatten_buffer.py | 261 ++++---------- ...est_tir_transform_inject_ptx_async_copy.py | 2 + .../test_tir_transform_loop_partition.py | 1 + .../test_tir_transform_lower_opaque_block.py | 329 ++++++++++++++++++ 12 files changed, 615 insertions(+), 330 deletions(-) create mode 100644 src/tir/transforms/lower_opaque_block.cc create mode 100644 tests/python/unittest/test_tir_transform_lower_opaque_block.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 005bf8410376..c758a00b3f0f 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -457,9 +457,14 @@ TVM_DLL Pass LegalizePackedCalls(); TVM_DLL Pass LowerMatchBuffer(); /*! - * \brief Flatten the multi-dimensional BufferLoad and BufferStore - * to single dimensional Load/Store. Also remove Block to - * ensure that the flattened TIR can not be scheduled again. + * \brief Remove the block to ensure that the TIR can not be scheduled again. + * \return The pass. + */ +TVM_DLL Pass LowerOpaqueBlock(); + +/*! + * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional + * BufferLoad/BufferStore for the TIR not contains opaque block. * \return The pass. */ TVM_DLL Pass FlattenBuffer(); diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 76fbf26eea31..92aaf8b4d992 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -111,16 +111,10 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) - # Currently, allocate nodes should only occur after buffer - # flattening has been applied. This can be simplified in - # the future by having the AllocateNode hold a buffer - # object directly. - flattened = self.buffer.get_flattened_buffer() - return tvm.tir.Allocate( self.buffer.data, - flattened.dtype, - flattened.shape, + self.buffer.dtype, + self.buffer.shape, condition, self.body, annotations=annotations, diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2a4ff6618a7f..6cc7b2e1f885 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -769,10 +769,20 @@ def LowerMatchBuffer(): return _ffi_api.LowerMatchBuffer() # type: ignore +def LowerOpaqueBlock(): + """Remove the block to ensure that the TIR can not be scheduled again. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerOpaqueBlock() # type: ignore + + def FlattenBuffer(): - """Flatten the multi-dimensional BufferLoad and BufferStore - to single dimensional Load/Store. Also remove Block to - ensure that the flattened TIR can not be scheduled again. + """Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional + BufferLoad/BufferStore for the TIR not contains opaque block. Returns ------- diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0446347eca2c..6f4fb618d334 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -202,6 +202,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 857b732c9804..dfe2c5a06a17 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 21de191db009..dcc23a72b27b 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -21,32 +21,17 @@ * \file flatten_buffer.cc */ -#include -#include -#include #include #include -#include "../../support/utils.h" #include "ir_utils.h" namespace tvm { namespace tir { -PrimExpr BufferArea(const Buffer& buffer) { - if (buffer->strides.size()) { - ICHECK(buffer->shape.size() == buffer->strides.size()); - return buffer->strides[0] * buffer->shape[0]; - } - PrimExpr area = Integer(1); - for (const PrimExpr& dim : buffer->shape) { - area = area * dim; - } - return area; -} - /*! * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension + * for the TIR not contains opaque block. */ class BufferFlattener : public StmtExprMutator { public: @@ -68,76 +53,25 @@ class BufferFlattener : public StmtExprMutator { } } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - // We have convert blocks into opaque blocks in previous passes. - ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " - "call pass ConvertBlocksToOpaque before."; - // Step 1. Visit the body - Block new_block = Downcast(this->VisitStmt(op->block)); - PrimExpr predicate = this->VisitExpr(op->predicate); - // Step 2. Transform the `predicate` to if-then-else - Stmt body = new_block->body; - if (!is_one(predicate)) { - body = IfThenElse(predicate, std::move(body)); - } - // Step 3. Handle allocations in reverse order - for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]); - body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body)); - } - return body; - } - - Stmt VisitStmt_(const ForNode* op) final { - // Step 1. Update unit loop info. - PrimExpr min = this->VisitExpr(op->min); - PrimExpr extent = this->VisitExpr(op->extent); - if (is_one(extent) && op->annotations.empty()) { - // handling unit loop - unit_loop_vars_[op->loop_var] = min; - } - // Step 2. Visit recursively - Stmt body = this->VisitStmt(op->body); - // Step 3. Create new For loop accordingly - if (op->kind == ForKind::kThreadBinding) { - // Case 1. Thread binding - ICHECK(op->thread_binding.defined()); - String thread_tag = op->thread_binding.value()->thread_tag; - body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); - } else if (is_one(extent) && op->annotations.empty()) { - // Case 2. Unit loop - return body; - } else { - // Case 3. An ordinary loop - body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body)); - } - // Step 4. Handle annotations - std::set ordered_ann_keys; - for (const auto& annotation : op->annotations) { - ordered_ann_keys.insert(annotation.first); - } - for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) { - const std::string& ann_key = *it; - const ObjectRef& ann_value = op->annotations.at(ann_key); - if (attr::IsPragmaKey(ann_key)) { - body = - AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body)); - } + Stmt VisitStmt_(const AllocateNode* op) final { + Allocate alloc = Downcast(StmtExprMutator::VisitStmt_(op)); + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (alloc->dtype == DataType::Bool()) { + auto writer = alloc.CopyOnWrite(); + writer->dtype = DataType::Int(8); } - return body; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); - auto it = unit_loop_vars_.find(var); - if (it == unit_loop_vars_.end()) { - return std::move(var); + // Handle multi-dimension allocations + if (alloc->extents.size() == 1) { + return std::move(alloc); } else { - PrimExpr expr = it->second; - if (expr.dtype() != var.dtype()) { - expr = tvm::cast(var.dtype(), std::move(expr)); + Array flat_extent(static_cast(1), 1); + for (size_t i = 0; i < alloc->extents.size(); i++) { + flat_extent.Set(0, flat_extent[0] * alloc->extents[i]); } - return expr; + auto n = alloc.CopyOnWrite(); + n->extents = flat_extent; + return std::move(alloc); } } @@ -146,7 +80,6 @@ class BufferFlattener : public StmtExprMutator { if (it != buffer_remap_.end()) { return it->second; } - auto flattened = buf.GetFlattenedBuffer(); // TODO(Lunderberg): Move the handling of boolean into a @@ -208,40 +141,6 @@ class BufferFlattener : public StmtExprMutator { return node; } - static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, - Stmt body) { - IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), - /*var=*/std::move(var), - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/thread_tag); - String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || - thread_tag == "vthread.y" || thread_tag == "vthread.z") - ? attr::virtual_thread - : attr::thread_extent; - return AttrStmt(/*node=*/std::move(iter_var), - /*attr_key=*/std::move(attr_key), - /*value=*/std::move(extent), - /*body=*/std::move(body)); - } - - /*! \brief Convert attr value from annotation map into PrimExpr. */ - PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) { - if (!obj.defined()) { - return PrimExpr(); - } else if (const PrimExprNode* expr = obj.as()) { - return GetRef(expr); - } else if (const StringObj* str = obj.as()) { - return std::move(StringImm(str->data)); - } else { - LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey() - << " not supported"; - return PrimExpr(); - } - } - - /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ - std::unordered_map unit_loop_vars_; - /*! \brief Map of buffers being remapped. */ std::unordered_map buffer_remap_; diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc new file mode 100644 index 000000000000..69d8787aa1a1 --- /dev/null +++ b/src/tir/transforms/lower_opaque_block.cc @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_opaque_block.cc + */ + +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Remove Block to ensure that the TIR can not be scheduled again. + */ +class OpaqueBlockLower : public StmtExprMutator { + private: + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // We have convert blocks into opaque blocks in previous passes. + ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " + "call pass ConvertBlocksToOpaque before."; + // Step 1. Visit the body + Block new_block = Downcast(this->VisitStmt(op->block)); + PrimExpr predicate = this->VisitExpr(op->predicate); + // Step 2. Transform the `predicate` to if-then-else + Stmt body = new_block->body; + if (!is_one(predicate)) { + body = IfThenElse(predicate, std::move(body)); + } + // Step 3. Handle allocations in reverse order + for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { + const Buffer& buffer = new_block->alloc_buffers[i - 1]; + Array new_shape = buffer->shape; + if (buffer->strides.size()) { + ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); + for (size_t i = buffer->strides.size() - 1; i > 0; --i) { + ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))); + new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); + } + } + body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body)); + } + return body; + } + + Stmt VisitStmt_(const ForNode* op) final { + // Step 1. Update unit loop info. + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + if (is_one(extent) && op->annotations.empty()) { + // handling unit loop + unit_loop_vars_[op->loop_var] = min; + } + // Step 2. Visit recursively + Stmt body = this->VisitStmt(op->body); + // Step 3. Create new For loop accordingly + if (op->kind == ForKind::kThreadBinding) { + // Case 1. Thread binding + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); + } else if (is_one(extent) && op->annotations.empty()) { + // Case 2. Unit loop + return body; + } else { + // Case 3. An ordinary loop + body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body)); + } + // Step 4. Handle annotations + std::set ordered_ann_keys; + for (const auto& annotation : op->annotations) { + ordered_ann_keys.insert(annotation.first); + } + for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) { + const std::string& ann_key = *it; + const ObjectRef& ann_value = op->annotations.at(ann_key); + if (attr::IsPragmaKey(ann_key)) { + body = + AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body)); + } + } + return body; + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = unit_loop_vars_.find(var); + if (it == unit_loop_vars_.end()) { + return std::move(var); + } else { + PrimExpr expr = it->second; + if (expr.dtype() != var.dtype()) { + expr = tvm::cast(var.dtype(), std::move(expr)); + } + return expr; + } + } + + static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, + Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), + /*var=*/std::move(var), + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; + return AttrStmt(/*node=*/std::move(iter_var), + /*attr_key=*/std::move(attr_key), + /*value=*/std::move(extent), + /*body=*/std::move(body)); + } + + /*! \brief Convert attr value from annotation map into PrimExpr. */ + PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) { + if (!obj.defined()) { + return PrimExpr(); + } else if (const PrimExprNode* expr = obj.as()) { + return GetRef(expr); + } else if (const StringObj* str = obj.as()) { + return std::move(StringImm(str->data)); + } else { + LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey() + << " not supported"; + return PrimExpr(); + } + } + + /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ + std::unordered_map unit_loop_vars_; +}; + +PrimFunc LowerOpaqueBlock(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + auto fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockLower()(std::move(fptr->body)); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass LowerOpaqueBlock() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LowerOpaqueBlock(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 10e827978cc0..d250fada6ae4 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -115,6 +115,7 @@ def test_buffer_vload_nullptr(): [ tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), tvm.tir.transform.CompactBufferAllocation(), + tvm.tir.transform.LowerOpaqueBlock(), tvm.tir.transform.FlattenBuffer(), ] )(mod) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index f1a33a4fb203..ea9c604e718a 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te, tir +import tvm.testing +from tvm import te from tvm.script import tir as T @@ -28,24 +29,15 @@ def _check(original, transformed): @T.prim_func -def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: +def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - for i in range(0, 16): - with T.block(): - T.reads(A[i, 0:16]) - T.writes(C[i, 0:16]) - B = T.alloc_buffer([1, 16], "float32", scope="global") - for j in range(0, 16): - with T.block(): - T.reads(A[i, j]) - T.writes(B[0, j]) - B[0, j] = A[i, j] + 1.0 - for j in range(0, 16): - with T.block(): - T.reads(B[0, j]) - T.writes(C[i, j]) - C[i, j] = B[0, j] * 2.0 + for i in T.serial(0, 16): + B_new = T.allocate([1, 16], "float32", "global") + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 @T.prim_func @@ -63,26 +55,22 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def compacted_gpu_func(a: T.handle, c: T.handle) -> None: +def gpu_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): - for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): - for i2 in T.thread_binding(0, 2, thread="vthread"): - with T.block(): - T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) - T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) - B = T.alloc_buffer([1, 16], "float32", scope="local") - for j in range(0, 16): - with T.block(): - T.reads(A[i0 * 4 + i1 * 2 + i2, j]) - T.writes(B[0, j]) - B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 - for j in range(0, 16): - with T.block(): - T.reads(B[0, j]) - T.writes(C[i0 * 4 + i1 * 2 + i2, j]) - C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.allocate([1, 16], "float32", "local") + for j in range(0, 16): + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 @T.prim_func @@ -107,25 +95,16 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: @T.prim_func -def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: +def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, (n, m), "float32") C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with T.block(): - T.reads(A[i, m]) - T.writes(C[i, m]) - B = T.alloc_buffer((m,), "float32", scope="global") - for j in range(0, m): - with T.block(): - T.reads(A[i, j]) - T.writes(B[j]) - B[j] = A[i, j] + 1.0 - for j in range(0, m): - with T.block(): - T.reads(B[j]) - T.writes(C[i, j]) - C[i, j] = B[j] * 2.0 + B = T.allocate([m], "float32", "global") + for j in range(0, m): + B[j] = A[i, j] + 1.0 + for j in range(0, m): + C[i, j] = B[j] * 2.0 @T.prim_func @@ -144,105 +123,44 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> @T.prim_func -def compacted_predicate_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - - for i, j in T.grid(5, 7): - with T.block(): - T.reads(A[i * 7 + j]) - T.writes(C[i * 7 + j]) - T.where(i * 7 + j < 32) - C[i * 7 + j] = A[i * 7 + j] + 1.0 - - -@T.prim_func -def flattened_predicate_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(C, (32), "float32", data=C.data) - - for i, j in T.grid(5, 7): - if i * 7 + j < 32: - C[i * 7 + j] = A[i * 7 + j] + 1.0 - - -@T.prim_func -def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - - for x, y, z in T.grid(4, 1, 8): - with T.block(): - T.reads(A[x * 8 + y * 8 + z]) - T.writes(C[x * 8 + y * 8 + z]) - C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 - - -@T.prim_func -def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - C = T.match_buffer(c, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(C, (32), "float32", data=C.data) - - for x, z in T.grid(4, 8): - C[x * 8 + z] = A[x * 8 + z] + 1.0 - - -@T.prim_func -def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - D = T.match_buffer(d, (32), "float32") +def multi_alloc_func(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (4, 32), "float32") + D = T.match_buffer(d, (4, 32), "float32") - for i in range(0, 32): - with T.block(): - T.reads(A[i]) - T.writes(D[i]) - B = T.alloc_buffer((32,), scope="global") - C = T.alloc_buffer((32,), scope="global") - B[i] = A[i] + 1.0 - C[i] = A[i] + B[i] - D[i] = C[i] * 2.0 + for i, j in T.grid(4, 32): + B = T.allocate((4, 32), "float32", scope="global") + C = T.allocate((4, 32), "float32", scope="global") + B[i, j] = A[i, j] + 1.0 + C[i, j] = A[i, j] + B[i, j] + D[i, j] = C[i, j] * 2.0 @T.prim_func def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (32), "float32") - D = T.match_buffer(d, (32), "float32") - T.preflattened_buffer(A, (32), "float32", data=A.data) - T.preflattened_buffer(D, (32), "float32", data=D.data) + A = T.match_buffer(a, (128), "float32") + D = T.match_buffer(d, (128), "float32") + T.preflattened_buffer(A, (4, 32), "float32", data=A.data) + T.preflattened_buffer(D, (4, 32), "float32", data=D.data) - for i in range(0, 32): - B = T.allocate((32,), "float32", "global") - C = T.allocate((32,), "float32", "global") - B[i] = A[i] + 1.0 - C[i] = A[i] + B[i] - D[i] = C[i] * 2.0 + for i, j in T.grid(4, 32): + B = T.allocate((128), "float32", "global") + C = T.allocate((128), "float32", "global") + B[i * 32 + j] = A[i * 32 + j] + 1.0 + C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] + D[i * 32 + j] = C[i * 32 + j] * 2.0 @T.prim_func -def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: +def strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - for i0 in range(0, 4): - with T.block(): - T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) - T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) - B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") - for i1 in range(0, 4): - for j in range(0, 16): - with T.block(): - T.reads(A[i0 * 4 + i1, j]) - T.writes(B[i1, j]) - B[i1, j] = A[i0 * 4 + i1, j] + 1.0 - for i1 in range(0, 4): - for j in range(0, 16): - with T.block(): - T.reads(B[i1, j]) - T.writes(C[i0 * 4 + i1, j]) - C[i0 * 4 + i1, j] = B[i1, j] * 2.0 + for i0 in T.serial(4): + B = T.allocate([4, 17], "float32", "global") + B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + for i1, j in T.grid(4, 16): + B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 + for i1, j in T.grid(4, 16): + C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 @T.prim_func @@ -261,20 +179,10 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 -@T.prim_func -def annotated_loops(a: T.handle) -> None: - A = T.match_buffer(a, (16,), "float32") - for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}): - A[i] = 0.0 - - @T.prim_func def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: for i0 in T.serial(10): - with T.block("b"): - T.reads(a[i0]) - T.writes(b[i0]) - b[i0] = a[i0] + b[i0] = a[i0] @T.prim_func @@ -286,41 +194,24 @@ def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") -@T.prim_func -def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: - T.preflattened_buffer(a, [10], dtype="bool", data=a.data) - T.preflattened_buffer(b, [10], dtype="bool", data=b.data) - # body - for i0 in T.serial(10): - b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") - - def test_elementwise(): - _check(compacted_elementwise_func, flattened_elementwise_func) + _check(elementwise_func, flattened_elementwise_func) def test_gpu_workload(): - _check(compacted_gpu_func, flattened_gpu_func) + _check(gpu_func, flattened_gpu_func) def test_symbolic_shape(): - _check(compacted_symbolic_func, flattened_symbolic_func) - - -def test_predicate(): - _check(compacted_predicate_func, flattened_predicate_func) - - -def test_unit_loops(): - _check(compacted_unit_loop_func, flattened_unit_loop_func) + _check(symbolic_func, flattened_symbolic_func) def test_multi_alloc(): - _check(compacted_multi_alloc_func, flattened_multi_alloc_func) + _check(multi_alloc_func, flattened_multi_alloc_func) def test_strided_buffer(): - _check(compacted_strided_buffer_func, flattened_strided_buffer_func) + _check(strided_buffer_func, flattened_strided_buffer_func) def test_lower_te(): @@ -332,35 +223,9 @@ def test_lower_te(): tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE -def test_annotated_loops(): - mod = tvm.IRModule.from_expr(annotated_loops) - mod = tvm.tir.transform.FlattenBuffer()(mod) - # _check(annotated_loops, compacted_annotated_loops) - attr1 = mod["main"].body - attr2 = attr1.body - attr3 = attr2.body - assert attr1.attr_key == "pragma_1" and attr1.value == "str_value" - assert attr2.attr_key == "pragma_2" - tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1)) - assert attr3.attr_key == "pragma_3" - tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) - - def test_boolean_handling(): _check(boolean_handling_before, boolean_handling_after) - # mod = tvm.IRModule.from_expr(boolean_handling_before) - # mod = tvm.tir.transform.FlattenBuffer()(mod) - # print(mod.script()) if __name__ == "__main__": - test_elementwise() - test_gpu_workload() - test_symbolic_shape() - test_predicate() - test_unit_loops() - test_multi_alloc() - test_strided_buffer() - test_lower_te() - test_annotated_loops() - test_boolean_handling() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index d7e13f40aa14..1a906b2fb66e 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -127,6 +127,7 @@ def test_inject_async_copy(): f = generate_global_to_shared_vectorized_copy(dtype, vec_size) mod = tvm.IRModule.from_expr(f) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) if vec_size > 1: mod = tvm.tir.transform.VectorizeLoop()(mod) @@ -154,6 +155,7 @@ def test_inject_async_copy_shared_dyn(): f = ptx_global_to_shared_dyn_copy_fp16x8 mod = tvm.IRModule.from_expr(f) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 6cfe96664d89..86f2b6696b3d 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -611,6 +611,7 @@ def concat_func_3( def test_condition_mutually_exclusive(): mod = IRModule.from_expr(concat_func_3) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) mod = tvm.tir.transform.FlattenBuffer()(mod) mod = tvm.tir.transform.LoopPartition()(mod) mod = tvm.tir.transform.Simplify()(mod) diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py new file mode 100644 index 000000000000..9b18c407c40c --- /dev/null +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -0,0 +1,329 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import te +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +@T.prim_func +def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with T.block(): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer([1, 16], "float32", scope="global") + for j in range(0, 16): + with T.block(): + T.reads(A[i, j]) + T.writes(B[0, j]) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with T.block(): + T.reads(B[0, j]) + T.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + +@T.prim_func +def transformed_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i in T.serial(0, 16): + B_new = T.allocate([1, 16], "float32", "global") + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 + + +@T.prim_func +def compacted_gpu_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): + for i2 in T.thread_binding(0, 2, thread="vthread"): + with T.block(): + T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) + T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) + B = T.alloc_buffer([1, 16], "float32", scope="local") + for j in range(0, 16): + with T.block(): + T.reads(A[i0 * 4 + i1 * 2 + i2, j]) + T.writes(B[0, j]) + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + with T.block(): + T.reads(B[0, j]) + T.writes(C[i0 * 4 + i1 * 2 + i2, j]) + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + +@T.prim_func +def transformed_gpu_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.allocate([1, 16], "float32", "local") + for j in range(0, 16): + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + +@T.prim_func +def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + with T.block(): + T.reads(A[i, m]) + T.writes(C[i, m]) + B = T.alloc_buffer((m,), "float32", scope="global") + for j in range(0, m): + with T.block(): + T.reads(A[i, j]) + T.writes(B[j]) + B[j] = A[i, j] + 1.0 + for j in range(0, m): + with T.block(): + T.reads(B[j]) + T.writes(C[i, j]) + C[i, j] = B[j] * 2.0 + + +@T.prim_func +def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + B = T.allocate([m], "float32", "global") + for j in range(0, m): + B[j] = A[i, j] + 1.0 + for j in range(0, m): + C[i, j] = B[j] * 2.0 + + +@T.prim_func +def compacted_predicate_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") + + for i, j in T.grid(5, 7): + with T.block(): + T.reads(A[i * 7 + j]) + T.writes(C[i * 7 + j]) + T.where(i * 7 + j < 32) + C[i * 7 + j] = A[i * 7 + j] + 1.0 + + +@T.prim_func +def transformed_predicate_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") + + for i, j in T.grid(5, 7): + if i * 7 + j < 32: + C[i * 7 + j] = A[i * 7 + j] + 1.0 + + +@T.prim_func +def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") + + for x, y, z in T.grid(4, 1, 8): + with T.block(): + T.reads(A[x * 8 + y * 8 + z]) + T.writes(C[x * 8 + y * 8 + z]) + C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 + + +@T.prim_func +def transformed_unit_loop_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") + + for x, z in T.grid(4, 8): + C[x * 8 + z] = A[x * 8 + z] + 1.0 + + +@T.prim_func +def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + D = T.match_buffer(d, (32), "float32") + + for i in range(0, 32): + with T.block(): + T.reads(A[i]) + T.writes(D[i]) + B = T.alloc_buffer((32,), scope="global") + C = T.alloc_buffer((32,), scope="global") + B[i] = A[i] + 1.0 + C[i] = A[i] + B[i] + D[i] = C[i] * 2.0 + + +@T.prim_func +def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + D = T.match_buffer(d, (32), "float32") + + for i in range(0, 32): + B = T.allocate((32,), "float32", "global") + C = T.allocate((32,), "float32", "global") + B[i] = A[i] + 1.0 + C[i] = A[i] + B[i] + D[i] = C[i] * 2.0 + + +@T.prim_func +def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in range(0, 4): + with T.block(): + T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) + T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) + B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") + for i1 in range(0, 4): + for j in range(0, 16): + with T.block(): + T.reads(A[i0 * 4 + i1, j]) + T.writes(B[i1, j]) + B[i1, j] = A[i0 * 4 + i1, j] + 1.0 + for i1 in range(0, 4): + for j in range(0, 16): + with T.block(): + T.reads(B[i1, j]) + T.writes(C[i0 * 4 + i1, j]) + C[i0 * 4 + i1, j] = B[i1, j] * 2.0 + + +@T.prim_func +def transformed_strided_buffer_func( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: + # body + for i0 in T.serial(4): + B = T.allocate([4, 17], "float32", "global") + B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + for i1, j in T.grid(4, 16): + B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) + for i1, j in T.grid(4, 16): + C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2) + + +@T.prim_func +def annotated_loops(a: T.handle) -> None: + A = T.match_buffer(a, (16,), "float32") + for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}): + A[i] = 0.0 + + +@T.prim_func +def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: + for i0 in T.serial(10): + with T.block("b"): + T.reads(a[i0]) + T.writes(b[i0]) + b[i0] = a[i0] + + +@T.prim_func +def boolean_handling_after(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: + # body + for i0 in T.serial(10): + b[i0] = a[i0] + + +def test_elementwise(): + _check(compacted_elementwise_func, transformed_elementwise_func) + + +def test_gpu_workload(): + _check(compacted_gpu_func, transformed_gpu_func) + + +def test_symbolic_shape(): + _check(compacted_symbolic_func, transformed_symbolic_func) + + +def test_predicate(): + _check(compacted_predicate_func, transformed_predicate_func) + + +def test_unit_loops(): + _check(compacted_unit_loop_func, transformed_unit_loop_func) + + +def test_multi_alloc(): + _check(compacted_multi_alloc_func, transformed_multi_alloc_func) + + +def test_strided_buffer(): + _check(compacted_strided_buffer_func, transformed_strided_buffer_func) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.LowerOpaqueBlock()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # LowerOpaqueBlock should do nothing on TE + + +def test_annotated_loops(): + mod = tvm.IRModule.from_expr(annotated_loops) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + attr1 = mod["main"].body + attr2 = attr1.body + attr3 = attr2.body + assert attr1.attr_key == "pragma_1" and attr1.value == "str_value" + assert attr2.attr_key == "pragma_2" + tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1)) + assert attr3.attr_key == "pragma_3" + tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) + + +def test_boolean_handling(): + _check(boolean_handling_before, boolean_handling_after) + + +if __name__ == "__main__": + tvm.testing.main()