diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 52f5f49b0a12..28a2df628c53 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid- return (a, b) +def conv2d_nhwc_f16( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +): + inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16") + weight = te.placeholder( + (kernel_size, kernel_size, CI // groups, CO), name="weight", dtype="float16" + ) + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + tir.Cast( + value=padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ], + dtype="float32", + ) + * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32") + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def batch_matmul_nkkm_f16( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X", dtype="float16") + y = te.placeholder((B, K, M), name="Y", dtype="float16") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum( + tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), axis=[k] + ), + name="Z", + ) + return (x, y, z) + + def create_te_workload(name: str, idx: int) -> tir.PrimFunc: workload_func, params = CONFIGS[name] return te.create_prim_func(workload_func(*params[idx])) # type: ignore diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 71ff024217c7..cdb4aa9cfa20 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -87,3 +87,37 @@ def get_tensorize_loop_mapping( TensorizeInfo structure if a valid mapping is found, None otherwise """ return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore + + +@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") +class AutoTensorizeMappingInfo(Object): + """Necessary information used to perform transformations for tensorization.""" + + +def get_auto_tensorize_mapping_info( + sch: Schedule, block: BlockRV, desc_func: PrimFunc +) -> Optional[AutoTensorizeMappingInfo]: + """Get mapping info between a target block and an intrinsic description including layout + transformations to apply. 
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule to be tensorized
+    block : BlockRV
+        The compute block for auto tensorization
+    desc_func : PrimFunc
+        The prim func describing the computation to be tensorized
+
+    Returns
+    -------
+    auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
+        AutoTensorizeMappingInfo structure if potential mappings are found, None otherwise.
+
+    Note
+    ----
+    Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
+    We will need to apply the suggested layout transformations and then match against the tensor
+    intrinsics.
+    """
+    return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func)  # type: ignore
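A minimal usage sketch of the new Python API (a sketch only: the workload helper `batch_matmul_nkkm_f16` and the test usage of the WMMA intrin name both come from this patch, while `create_prim_func`, `Schedule`, and `TensorIntrin.get` are existing TVM APIs):

```python
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.tir.function import TensorIntrin
from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
from tvm.meta_schedule.testing import te_workload

# fp16 batch matmul workload (added below) vs. the WMMA 16x16x16 intrin description.
mod = create_prim_func(te_workload.batch_matmul_nkkm_f16(16, 32, 32, 32))
sch = Schedule(mod)
block = sch.get_block("Z")
desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f32_INTRIN).desc
info = get_auto_tensorize_mapping_info(sch, block, desc)
if info is not None:
    print(info.mappings)  # candidate layout transforms; still must be re-checked
```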
+ */ +Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, + const StmtSRef& block_sref, + const PrimFunc& desc_func); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7def8b8674e1..3ee1ed28b857 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -19,6 +19,7 @@ #include #include +#include "../ir_comparator.h" #include "../utils.h" namespace tvm { @@ -2085,39 +2086,60 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { - arith::Analyzer analyzer; - const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); - // Step 1. Analyze desc_func, extract its block, loops and loop vars - const tir::BlockRealizeNode* desc_block = nullptr; +/*! \brief Auxiliary data structure of information extracted from tensor intrin description */ +struct TensorIntrinDescInfo { + /*! \brief The block of the description function, which is the (unique) direct child of the root + * block. + */ + const BlockRealizeNode* desc_block = nullptr; + /*! \brief The loops of the description function, in the order from outer loops to inner ones. */ std::vector desc_loops; + /*! \brief The loop variables. */ std::unordered_set desc_loop_vars; - const auto* desc_scope_realize = desc_func->body.as(); +}; + +/*! + * \brief Extract auxilary information from the tensor intrin description. + * \param analyze The arithmetic analyzer + * \param desc_func The description PrimFunc + * \return The auxilary information + */ +TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, + const PrimFunc& desc_func) { + TensorIntrinDescInfo info; + const auto* desc_scope_realize = desc_func->body.as(); ICHECK(desc_scope_realize); { - auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, - &analyzer](const ObjectRef& obj) -> bool { + auto f_visit = [&](const ObjectRef& obj) -> bool { // Extract the block - if (const auto* block = obj.as()) { - desc_block = block; + if (const auto* block = obj.as()) { + info.desc_block = block; return false; } - // Extract loops - if (const auto* loop = obj.as()) { - desc_loops.push_back(loop); - desc_loop_vars.insert(loop->loop_var.get()); - if (!analyzer.CanProve(loop->min == 0)) { + // Extract the loops + if (const auto* loop = obj.as()) { + info.desc_loops.push_back(loop); + info.desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer->CanProve(loop->min == 0)) { return false; } } return true; }; tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); - std::reverse(desc_loops.begin(), desc_loops.end()); - ICHECK(desc_block); + std::reverse(info.desc_loops.begin(), info.desc_loops.end()); + ICHECK(info.desc_block); } + return info; +} + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 7def8b8674e1..3ee1ed28b857 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -19,6 +19,7 @@
 #include 
 #include 
 
+#include "../ir_comparator.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
@@ -2085,39 +2086,60 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
 
 TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
 
-Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
-                                                const tir::StmtSRef& block_sref,
-                                                const tir::PrimFunc& desc_func) {
-  arith::Analyzer analyzer;
-  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
-  // Step 1. Analyze desc_func, extract its block, loops and loop vars
-  const tir::BlockRealizeNode* desc_block = nullptr;
+/*! \brief Auxiliary data structure of information extracted from tensor intrin description */
+struct TensorIntrinDescInfo {
+  /*! \brief The block of the description function, which is the (unique) direct child of the root
+   * block.
+   */
+  const BlockRealizeNode* desc_block = nullptr;
+  /*! \brief The loops of the description function, in the order from outer loops to inner ones. */
   std::vector<const ForNode*> desc_loops;
+  /*! \brief The loop variables. */
   std::unordered_set<const VarNode*> desc_loop_vars;
-  const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>();
+};
+
+/*!
+ * \brief Extract auxiliary information from the tensor intrin description.
+ * \param analyzer The arithmetic analyzer
+ * \param desc_func The description PrimFunc
+ * \return The auxiliary information
+ */
+TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
+                                                 const PrimFunc& desc_func) {
+  TensorIntrinDescInfo info;
+  const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>();
   ICHECK(desc_scope_realize);
   {
-    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
-                    &analyzer](const ObjectRef& obj) -> bool {
+    auto f_visit = [&](const ObjectRef& obj) -> bool {
       // Extract the block
-      if (const auto* block = obj.as<BlockRealizeNode>()) {
-        desc_block = block;
+      if (const auto* block = obj.as<BlockRealizeNode>()) {
+        info.desc_block = block;
         return false;
       }
-      // Extract loops
-      if (const auto* loop = obj.as<ForNode>()) {
-        desc_loops.push_back(loop);
-        desc_loop_vars.insert(loop->loop_var.get());
-        if (!analyzer.CanProve(loop->min == 0)) {
+      // Extract the loops
+      if (const auto* loop = obj.as<ForNode>()) {
+        info.desc_loops.push_back(loop);
+        info.desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer->CanProve(loop->min == 0)) {
           return false;
         }
       }
       return true;
     };
     tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
-    std::reverse(desc_loops.begin(), desc_loops.end());
-    ICHECK(desc_block);
+    std::reverse(info.desc_loops.begin(), info.desc_loops.end());
+    ICHECK(info.desc_block);
   }
+  return info;
+}
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func);
+  // Step 2. Collect loops from block_sref
   const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
   const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
@@ -2138,6 +2160,9 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
     std::reverse(block_loops.begin(), block_loops.end());
   }
   // Step 3. Map from block loops to desc block loops
+  const std::vector<const ForNode*>& desc_loops = desc_info.desc_loops;
+  const std::unordered_set<const VarNode*>& desc_loop_vars = desc_info.desc_loop_vars;
+  const BlockRealizeNode* desc_block = desc_info.desc_block;
   ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
   const int n_block_vars = block->iter_values.size();
   const int n_desc_vars = desc_block->iter_values.size();
@@ -2240,5 +2265,217 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
     .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
       return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
     });
 
+/******** Auto Tensorization ********/
+
+/*! \brief IndexMap proposer for layout transformation in auto tensorization. */
+class AutoTensorizeMappingProposer {
+ public:
+  static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator* extractor,
+                                         arith::Analyzer* analyzer) {
+    AutoTensorizeMappingProposer proposer(extractor, analyzer);
+    proposer.CollectFeasibleSet();
+    return proposer.ProposeAllFuseMapping();
+  }
+
+ private:
+  explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* extractor,
+                                        arith::Analyzer* analyzer)
+      : extractor_(extractor), analyzer_(analyzer) {}
+
+  using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+  void CollectFeasibleSet() {
+    // Collect the set of potential iter var mappings between the workload and the tensor intrin.
+    // We analyze the appearance of each variable in the buffer indices of each buffer on the LHS
+    // and the RHS. The appearance of a variable in the buffer indices is encoded as a bit-mask
+    // (BufferMask). Variables on the LHS and the RHS with the same bit-mask and the same iter
+    // type are potential mappings.
+    //
+    // For example, consider the conv2d case. We will try to match the workload
+    // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c]
+    // against a matmul tensor intrin
+    // C[m, n] = sum_{k} A[m, k] * B[k, n]
+    // First, we extract the correspondence of the buffers: conv2d <=> C, A <=> X, B <=> W.
+    // Then, for each variable, we extract the buffers where it is used for indexing.
+    // Take the variable m on the RHS as an example: m is used to index buffers A and C. On the
+    // LHS, we look for the variables used to index only the exact corresponding buffers conv2d
+    // and X (the variable is not allowed to index other buffers). In this case, n, h, and w are
+    // used to index both conv2d and X, and no other buffers. Therefore, {n, h, w} <=> m is a
+    // potential mapping.
+    //
+    // Note: the mapping is not unique when multiple variables on the RHS have the same bit-mask.
+    // This is currently not supported.
+
+    using BufferMask = std::vector<bool>;
+
+    // Step 1: Assign an index to each buffer in LHS and RHS
+    std::unordered_map<Buffer, int, ObjectPtrHash, ObjectPtrEqual> rhs_buffer_index;
+    std::unordered_map<Buffer, int, ObjectPtrHash, ObjectPtrEqual> lhs_buffer_index;
+    {
+      int i = 0;
+      for (const auto& kv : extractor_->rhs_buffer_map_) {
+        const Buffer& rhs_buffer = kv.first;
+        const Buffer& lhs_buffer = kv.second;
+        rhs_buffer_index[rhs_buffer] = i;
+        lhs_buffer_index[lhs_buffer] = i;
+        ++i;
+      }
+    }
+
+    // Step 2: Compute the buffer mask
+    ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size());
+    int num_buffers = rhs_buffer_index.size();
+    std::unordered_map<const VarNode*, std::vector<bool>> rhs_buffer_masks, lhs_buffer_masks;
+    // helper function to initialize or update the buffer mask
+    auto update_mask = [&](const VarNode* var,
+                           std::unordered_map<const VarNode*, std::vector<bool>>* masks, int i) {
+      if (!masks->count(var)) {
+        (*masks)[var].resize(num_buffers);
+      }
+      (*masks)[var][i] = true;
+    };
+
+    for (const auto& it : extractor_->rhs_buffer_indices_map_) {
+      const Buffer& rhs_buffer = it.first;
+      for (const PrimExpr& rhs_index : it.second) {
+        if (const VarNode* var_node = rhs_index.as<VarNode>()) {
+          update_mask(var_node, &rhs_buffer_masks, rhs_buffer_index.at(rhs_buffer));
+        } else {
+          LOG(FATAL) << "ValueError: Buffer index " << rhs_index
+                     << " other than variables in tensor intrinsics is not supported.";
+        }
+      }
+
+      auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer);
+      ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end());
+      const Buffer& lhs_buffer = lhs_buffer_it->second;
+      for (const PrimExpr& index : extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) {
+        PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
+          if (const VarNode* var = obj.as<VarNode>()) {
+            update_mask(var, &lhs_buffer_masks, lhs_buffer_index.at(lhs_buffer));
+          }
+          return true;
+        });
+      }
+    }
+
+    // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars
+    // have the same iter type.
+    std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
+    for (const auto& kv : rhs_buffer_masks) {
+      const VarNode* rhs_var = kv.first;
+      const BufferMask& mask = kv.second;
+      mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var));
+    }
+    std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
+    for (const auto& iter : extractor_->rhs_iters_) {
+      rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type);
+    }
+    for (const auto& iter : extractor_->lhs_iters_) {
+      auto& potential_mappings = lhs_feasible_vars_[iter->var];
+      VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
+      std::copy_if(
+          rhs_candidates.begin(), rhs_candidates.end(),
+          std::inserter(potential_mappings, potential_mappings.begin()),
+          [&](const Var& var) { return rhs_var_iter_type.at(var.get()) == iter->iter_type; });
+    }
+  }
+
+  Array<IndexMap> ProposeAllFuseMapping() {
+    // Now we have calculated the potential mappings for each iter var on the LHS. Iters on the
+    // LHS that are mapped to the same iter on the RHS will be fused, keeping their original
+    // order among the LHS block iters. We will generate an IndexMap to represent such fusion on
+    // the LHS. For example, if n, h, w on the LHS are mapped to the same iter var on the RHS, we
+    // will produce the index map `lambda n, h, w: fuse(n, h, w)`, where
+    // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ...) * vn_extent + vn
+
+    // the parameters of the result index map, each parameter corresponds to a LHS iter
+    Array<Var> index_map_src;
+    // the outputs of the result index map
+    Array<PrimExpr> index_map_tgt;
+
+    // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap
+    Map<Var, PrimExpr> lhs_iter_extents;
+    for (const auto& iter : extractor_->lhs_iters_) {
+      lhs_iter_extents.Set(iter->var, iter->dom->extent);
+      index_map_src.push_back(iter->var.copy_with_suffix(""));
+    }
+
+    // Step 2: Each iter on RHS has a group of corresponding iters on LHS. Initialize the fusion
+    // result for each group of iters on LHS.
+    Map<Var, PrimExpr> fused_lhs_iters;
+    for (const auto& iter : extractor_->rhs_iters_) {
+      fused_lhs_iters.Set(iter->var, 0);
+    }
+
+    // Step 3: Fuse LHS iters mapped to the same RHS iter
+    for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
+      const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
+      const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
+      if (rhs_candidates.empty()) {
+        // put unmapped iters at the beginning
+        index_map_tgt.push_back(index_map_src[i]);
+      } else if (rhs_candidates.size() == 1) {
+        Var rhs_var = *rhs_candidates.begin();
+        PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var);
+        PrimExpr updated_fused_lhs =
+            fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
+        fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+      } else {
+        // non-unique mapping is not supported
+        return {};
+      }
+    }
+    for (const auto& iter : extractor_->rhs_iters_) {
+      index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
+    }
+    // At most one mapping is supported.
+    return {IndexMap(index_map_src, index_map_tgt)};
+  }
+
+ private:
+  // The extractor that has extracted information for auto tensorization from the workload and the
+  // tensor intrin.
+  const AutoTensorizeComparator* extractor_;
+  // The arithmetic analyzer.
+  arith::Analyzer* analyzer_;
+  /*! \brief Potential mappings on RHS for each variable on LHS */
+  std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> lhs_feasible_vars_;
+};
+
+Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const tir::ScheduleState& self,
+                                                               const tir::StmtSRef& block_sref,
+                                                               const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func);
+  // Step 2. Check if `desc_block` matches `block`
+  // Ignore the scope of buffers when comparing, since we can do cache_read/write
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  AutoTensorizeComparator extractor(self->mod);
+  if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) {
+    return NullOpt;
+  }
+  Array<IndexMap> mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer);
+  if (mappings.empty()) {
+    return NullOpt;
+  }
+  ObjectPtr<AutoTensorizeMappingInfoNode> ret = make_object<AutoTensorizeMappingInfoNode>();
+  ret->mappings = std::move(mappings);
+  ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_);
+  ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_);
+  ret->lhs_iters = std::move(extractor.lhs_iters_);
+  ret->rhs_iters = std::move(extractor.rhs_iters_);
+  return AutoTensorizeMappingInfo(ret);
+}
+
+TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode);
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
      return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func);
    });
+
 }  // namespace tir
 }  // namespace tvm
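Two short sketches to make the machinery above concrete. First, the bit-mask matching described in `CollectFeasibleSet`, redone as a toy model in plain Python (not TVM API; the buffer correspondence for the conv2d-vs-matmul example is conv2d <=> C, X <=> A, W <=> B, and we assume groups == 1 so X's channel index is just rc):

```python
# Which buffers does each iter var index? Encode as a set over buffer slots
# 0: conv2d/C, 1: X/A, 2: W/B.
lhs_masks = {
    "n": {0, 1}, "h": {0, 1}, "w": {0, 1},     # index conv2d and X
    "c": {0, 2},                               # indexes conv2d and W
    "rh": {1, 2}, "rw": {1, 2}, "rc": {1, 2},  # index X and W
}
rhs_masks = {"m": {0, 1}, "n": {0, 2}, "k": {1, 2}}  # from C[m,n] += A[m,k] * B[k,n]

# Vars with identical masks are potential mappings:
# {n, h, w} <=> m, {c} <=> n, {rh, rw, rc} <=> k
for lhs_var, mask in lhs_masks.items():
    print(lhs_var, "->", [r for r, rmask in rhs_masks.items() if rmask == mask])
```

Second, the end-to-end flow implied by the note on `GetAutoTensorizeMappingInfo`: apply a proposed mapping, then re-run the exact matcher. This is only a sketch using existing schedule primitives (`transform_block_layout` applies an `IndexMap` to the block iters), not code added by this patch:

```python
from tvm.tir.schedule.analysis import (
    get_auto_tensorize_mapping_info,
    get_tensorize_loop_mapping,
)

def try_auto_tensorize(sch, block, desc_func):
    """Sketch: propose a layout transform, apply it, then re-check the exact match."""
    info = get_auto_tensorize_mapping_info(sch, block, desc_func)
    if info is None:
        return False
    # Currently at most one mapping is proposed (see ProposeAllFuseMapping).
    sch.transform_block_layout(block, info.mappings[0])
    # Only now do we know whether the block really matches the intrin.
    return get_tensorize_loop_mapping(sch, block, desc_func) is not None
```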
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 58c502379a7a..d8ac40ef0586 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -333,12 +333,12 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
   return true;
 }
 
-template <typename T, typename F>
-bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp) {
+template <typename T, typename Self, typename F>
+bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) {
   if (lhs.same_as(rhs)) return true;
   if (lhs.size() != rhs.size()) return false;
   for (size_t i = 0; i < lhs.size(); ++i) {
-    if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+    if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;
   }
   return true;
 }
@@ -355,5 +355,125 @@ void TensorizeComparator::EmitError(const std::string& error_message) {
   error_messages_.push_back(error_message);
 }
 
+/******** AutoTensorize Extractor ********/
+
+bool AutoTensorizeComparator::VisitExprDefault_(const Object* op, const PrimExpr& other) {
+  return false;
+}
+
+bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& other) {
+  return false;
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BlockNode>();
+  // Check block equality.
+  // All iter vars and buffer regions including the order should match.
+  // When checking iter vars, DefEqual is used to remap variables.
+  if (!is_scope_block) {
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, &AutoTensorizeComparator::CompareIterVar)) {
+      return false;
+    }
+    if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+      return false;
+    }
+    if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
+                      &AutoTensorizeComparator::CompareBuffer)) {
+      return false;
+    }
+    for (const IterVar& block_iter : op->iter_vars) {
+      inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom));
+    }
+  } else {
+    auto collect_iter = [&](const BlockNode* op, std::vector<IterVar>& iters) -> bool {
+      for (const auto& iter : op->iter_vars) {
+        analyzer_.Bind(iter->var, iter->dom);
+        if (iter->iter_type == IterVarType::kDataPar ||
+            iter->iter_type == IterVarType::kCommReduce) {
+          iters.push_back(iter);
+        } else {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!collect_iter(op, lhs_iters_)) {
+      return false;
+    }
+    if (!collect_iter(rhs, rhs_iters_)) {
+      return false;
+    }
+  }
+  is_scope_block = false;
+  return VisitStmt(op->body, rhs->body);
+}
+
+bool AutoTensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  auto it = rhs_buffer_map_.find(rhs);
+  bool equal;
+  if (it != rhs_buffer_map_.end()) {
+    equal = (*it).second.same_as(lhs);
+  } else {
+    // Remap both buffer itself and buffer data, skip buffer shape and scope
+    equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype;
+    if (equal) {
+      rhs_buffer_map_[rhs] = lhs;
+      lhs_buffer_map_[lhs] = rhs;
+    }
+  }
+  return equal;
+}
+
+bool AutoTensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BufferStoreNode>();
+  return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool AutoTensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
+  const auto* rhs = other.as<BufferLoadNode>();
+  return CompareBufferAccess(op, rhs);
+}
+
+template <typename T>
+bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
+  if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+  auto it_lhs = lhs_buffer_indices_map_.find(lhs->buffer);
+  if (it_lhs == lhs_buffer_indices_map_.end()) {
+    if (rhs_buffer_indices_map_.find(rhs->buffer) != rhs_buffer_indices_map_.end()) {
+      return false;
+    }
+    std::vector<PrimExpr> lhs_indices;
+    for (const auto& index : lhs->indices) {
+      lhs_indices.push_back(analyzer_.Simplify(index));
+    }
+    for (const auto& index : rhs->indices) {
+      if (!index.template as<VarNode>()) return false;
+    }
+    lhs_buffer_indices_map_[lhs->buffer] = lhs_indices;
+    rhs_buffer_indices_map_[rhs->buffer] = rhs->indices;
+  } else {
+    auto it_rhs = rhs_buffer_indices_map_.find(rhs->buffer);
+    if (it_rhs == rhs_buffer_indices_map_.end()) {
+      return false;
+    }
+    auto indices_check = [&](const Array<PrimExpr>& indices,
+                             const Array<PrimExpr>& old_indices) -> bool {
+      if (indices.size() != old_indices.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < indices.size(); ++i) {
+        if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!indices_check(lhs->indices, it_lhs->second)) return false;
+    if (!indices_check(rhs->indices, it_rhs->second)) return false;
+  }
+  return true;
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 359677d8852f..394d82867393 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -90,8 +90,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
   bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const Map<String, ObjectRef>& rhs);
   template <typename T>
   bool CompareBufferAccess(const T* lhs, const T* rhs);
-  template <typename T, typename F>
-  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+  template <typename T, typename Self, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp);
   bool CompareRange(const Range& lhs, const Range& rhs);
   bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
   void EmitError(const std::string& error_message);
@@ -110,6 +110,54 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
   std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_;
 };
 
+/*!
+ * \brief IR comparator for auto tensorization.
+ * This comparator is used to extract correspondence between the IR of the workload (LHS) and the
+ * tensor intrin (RHS). Unlike `TensorizeComparator`, this comparator has relaxed requirements
+ * during comparison. It ignores the loop structure (number of loops and their extents) and buffer
+ * indices. It only requires the LHS and the RHS to have the same arithmetic operations and the
+ * same dtype. With such relaxed requirements, workloads that can only match the tensor intrin
+ * after certain transformations (e.g. im2col for conv2d) are allowed for auto tensorization.
+ */
+class AutoTensorizeComparator : public TensorizeComparator {
+ public:
+  explicit AutoTensorizeComparator(const IRModule& lhs_mod)
+      : TensorizeComparator(lhs_mod, /*assert_mode=*/false) {}
+
+ private:
+  bool VisitExprDefault_(const Object* op, const PrimExpr& other) override;
+  bool VisitStmtDefault_(const Object* op, const Stmt& other) override;
+
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+
+  bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override;
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs);
+
+ public:
+  // Additional information extracted from LHS (the workload) and RHS (the tensor intrin).
+
+  /*! \brief Block iters in the LHS stmt. */
+  std::vector<IterVar> lhs_iters_;
+  /*! \brief Block iters in the RHS stmt. */
+  std::vector<IterVar> rhs_iters_;
+  /*! \brief The buffer and its access indices in the LHS stmt. */
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+      lhs_buffer_indices_map_;
+  /*! \brief The buffer and its access indices in the RHS stmt. */
+  std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
+      rhs_buffer_indices_map_;
+  /*! \brief Map from LHS buffer to RHS buffer */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> lhs_buffer_map_;
+
+ private:
+  /*! \brief The domain of the inner block iters. */
+  Map<Var, arith::IntSet> inner_iter_dom_map_;
+};
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 19be0b8699ac..6761203a5a4d 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -16,14 +16,22 @@
 # under the License.
# pylint: disable=missing-docstring from typing import List - +import pytest import tvm +import tvm.testing +from tvm.tir.function import TensorIntrin from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc +from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo +from tvm.tir.schedule.analysis import ( + get_auto_tensorize_mapping_info, + suggest_index_map, + get_tensorize_loop_mapping, + TensorizeInfo, +) from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload @@ -252,9 +260,43 @@ def matmul_16x16x16xf16f16f16_desc( assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) +def check_index_map(workload, block_name, intrin_name, expected_index_map): + s = Schedule(workload) + block = s.get_block(block_name) + desc_func = TensorIntrin.get(intrin_name).desc + info = get_auto_tensorize_mapping_info(s, block, desc_func) + assert len(info.mappings) == 1 + assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0]) + + +def test_get_auto_tensorize_mapping_info_conv2d(): + conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64, 3, 1, 1)) + check_index_map( + conv2d, + "conv2d_nhwc", + WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc), + ) + + +def test_get_auto_tensorize_mapping_info_conv2d_unit_batch(): + conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(1, 16, 16, 64, 64, 3, 1, 1)) + check_index_map( + conv2d, + "conv2d_nhwc", + WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + # unit iter is not mapped + lambda n, h, w, c, rh, rw, rc: (n, h * 16 + w, c, rh * 192 + rw * 64 + rc), + ) + + +@pytest.mark.parametrize("b,m,n,k", [(1, 512, 512, 512), (16, 32, 32, 32)]) +def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k): + matmul = create_prim_func(te_workload.batch_matmul_nkkm_f16(b, m, n, k)) + check_index_map( + matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, lambda b, m, n, k: (b, m, n, k) + ) + + if __name__ == "__main__": - test_suggest_index_map_simple() - test_suggest_index_map_bijective() - test_get_tensorize_loop_mapping_dense_vnni() - test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() - test_get_tensorize_loop_mapping_matmul_mma() + tvm.testing.main()
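As a sanity check on the expected mappings asserted above: with H = W = 16, padding 1, and a 3x3 kernel, the conv2d output extents are out_h = out_w = 16 and the reduction extents are 3, 3, 64, so the all-fuse formula from `ProposeAllFuseMapping` yields exactly the asserted index map. A quick standalone check, assuming only `tvm.tir.IndexMap`:

```python
from tvm.tir import IndexMap

# fuse(n, h, w)    = (n * 16 + h) * 16 + w   = n * 256 + h * 16 + w
# fuse(rh, rw, rc) = (rh * 3 + rw) * 64 + rc = rh * 192 + rw * 64 + rc
expected = IndexMap.from_func(
    lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc)
)
print(expected)
```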