diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5b44f79ad70a5..9fb212307bfca 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -411,8 +411,10 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. + * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be + * stored. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate); /*! * \brief The prefetch hint for a buffer diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 39b32f563350a..b3673c4bb3569 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode { Buffer buffer; /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief The predicate mask for loading values. */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode { hash_reduce(dtype); hash_reduce(buffer); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferLoad"; @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, + PrimExpr predicate = PrimExpr(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 07cc9b5ad0d50..b60e7a80cfae3 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode { PrimExpr value; /*! \brief The indices location to be stored. */ Array indices; + /*! \brief The predicate mask for storing values. */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); v->Visit("value", &value); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode { hash_reduce(buffer); hash_reduce(value); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferStore"; @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Span span = Span()); + PrimExpr predicate = PrimExpr(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index cb6e031667c53..756dbc4992f4d 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -57,6 +57,31 @@ def _updater(data): return _updater +def create_updater_16_to_17(): + """ + Create an update to upgrade json from v0.16 to v0.17 + + Returns + ------- + fupdater : function + The updater function + """ + + def _update_predicate_argument(item, nodes): + null_value_idx = 0 + null_value = nodes[null_value_idx] + assert str(null_value) == "{'type_key': ''}", f"Expected a null value but got {null_value}" + item["attrs"]["predicate"] = str(null_value_idx) + return item + + node_map = { + "tir.BufferLoad": _update_predicate_argument, + "tir.BufferStore": _update_predicate_argument, + } + + return create_updater(node_map, "0.16", "0.17") + + def create_updater_15_to_16(): """ Create an update to upgrade json from v0.15 to v0.16 @@ -316,5 +341,7 @@ def _from_version(data): data = create_updater({}, "0.14", "0.15")(data) if _from_version(data).startswith("0.15"): data = create_updater_15_to_16()(data) + if _from_version(data).startswith("0.16"): + data = create_updater_16_to_17()(data) return json.dumps(data, indent=2) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5a0a564a2ab59..1550ebc49efa2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1265,6 +1265,7 @@ def buffer_store( buffer: Buffer, # pylint: disable=redefined-outer-name value: PrimExpr, indices: List[Union[PrimExpr, slice]], + predicate: Optional[PrimExpr] = None, ) -> None: """Buffer store node. @@ -1278,6 +1279,9 @@ def buffer_store( indices : List[Union[PrimExpr, slice]] The indices location to be stored. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel @@ -1298,7 +1302,7 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices + buffer, value, expr_indices, predicate ) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 679ae4e8adc08..600099bb0afba 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, str): # Ignore docstrings pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index ec57ad7801caf..b6de8791dea1e 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -141,6 +141,57 @@ def vstore(self, begin, value): begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin return _ffi_api.BufferVStore(self, begin, value) # type: ignore + def load(self, indices, predicate=None): + """ + Load values at specified indices from buffer. + + Longhand notation that can be used for complex buffer load + expressions. For example, when the load involves predication. + + Parameters + ---------- + indices : List[PrimExpr] + The buffer indices to load values from. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be loaded. + + Returns + ------- + BufferLoad + A buffer load Expr. + """ + from .expr import BufferLoad # pylint: disable=import-outside-toplevel + + return BufferLoad(self, indices, predicate) + + def store(self, value, indices, predicate=None): + """ + Store given value at the specified indices in the buffer. + + Longhand notation that can be used for complex buffer store + statements. For example, when the store involves predication. + + Parameters + ---------- + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The buffer indices to store values to. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. + + Returns + ------- + BufferStore + A buffer store Stmt. + """ + from .stmt import BufferStore # pylint: disable=import-outside-toplevel + + return BufferStore(self, value, indices, predicate) + def scope(self): """Return the storage scope associated with this buffer. Returns diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fca501874d940..b9ea2c414d268 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1093,20 +1093,27 @@ class BufferLoad(PrimExprWithOp): The buffer to be loaded. indices : List[PrimExpr] - The buffer indices. + The buffer indices to load values from. span : Optional[Span] The location of this expression in the source code. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be loaded. """ buffer: Buffer indices: List[PrimExpr] def __init__( - self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + self, + buffer: Buffer, + indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, + span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferLoad, buffer, indices, span # type: ignore + _ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 992c388e27bb9..6f8ce42cd9381 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -224,6 +224,9 @@ class BufferStore(Stmt): indices : List[PrimExpr] The indices location to be stored. + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. + span : Optional[Span] The location of the stmt in the source code. """ @@ -231,6 +234,7 @@ class BufferStore(Stmt): buffer: Buffer value: PrimExpr indices: List[PrimExpr] + predicate: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -238,10 +242,11 @@ def __init__( buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + _ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc9..3026f6e58f187 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 3ce5c15e6cd06..121d531fc9e06 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices) { +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + PrimExpr predicate = PrimExpr()) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices) { } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices)); + AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } void Prefetch(Buffer buffer, Array bounds) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 45a0dfd2aea4f..078d34fbba7b8 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -273,14 +273,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); - return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + ExprDoc value = d->AsDoc(store->value, p->Attr("value")); + + // Use .store(...) syntax when there is a predicate + if (store->predicate.defined()) { + ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); + return ExprStmtDoc( + buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate})); + } + + return AssignDoc( + /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], + /*rhs=*/value, NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + + // Use .load(...) syntax when there is a predicate + if (load->predicate.defined()) { + ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); + return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate}); + } + return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6566bb4291d83..5469598fdb4e5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1664,9 +1664,9 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction) { DataType buffer_element_dtype = buffer->dtype; @@ -1746,6 +1746,11 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* predicate_value = nullptr; + if (predicate.defined()) { + predicate_value = MakeValue(predicate); + } + TypedPointer buffer_ptr = value_dtype.is_scalable_vector() ? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, @@ -1754,7 +1759,8 @@ void CodeGenLLVM::BufferAccessHelper( : CreateBufferPtr( MakeValue(buffer->data), buffer_element_dtype, all_index_values, value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); - auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + auto instruction = + make_instruction(buffer_ptr, subelement_i, predicate_value, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } } @@ -1764,11 +1770,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { std::vector loads; - auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, - bool is_volatile) { + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, + llvm::Value* predicate, int alignment, bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + llvm::Instruction* load = nullptr; + if (predicate != NULL) { + load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + predicate); + } else { + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #elif TVM_LLVM_VERSION >= 80 auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); @@ -1783,7 +1795,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load); if (loads.size() == 1) { return loads[0]; @@ -1898,24 +1910,32 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* value = MakeValue(op->value); - auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, - bool is_volatile) { + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate, + int alignment, bool is_volatile) { llvm::Value* to_store = value; + llvm::Instruction* store; + if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), - is_volatile); + if (predicate != NULL) { + store = + builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); + } else { + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #else - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif + return store; }; // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 0f7aa847ecb88..88887dfd49981 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -330,6 +330,9 @@ class CodeGenLLVM : public ExprFunctor, * * \param indices The indices at which the buffer is being accessed. * + * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be + * stored. + * * \param value_dtype The datatype to be read from (BufferLoad) or * written to (BufferStore) the buffer. * @@ -342,6 +345,8 @@ class CodeGenLLVM : public ExprFunctor, * stored/loaded. If -1, indicates that the entire type, * vector or scalar, should be written. * + * - predicate: The predicate mask of the buffer. + * * - alignment: The alignment to be used for the read/write. * * - is_volatile: Whether the read/write should be volatile. @@ -349,9 +354,9 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction); // Initialize target virtual void InitTarget(); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 03de68e326248..c7dbf3f5e042f 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -79,7 +79,7 @@ class BufferSubstituter : public StmtExprMutator { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_map_.find(load->buffer.get()); if (it != buffer_map_.end()) { - return BufferLoad(it->second, load->indices, load->span); + return BufferLoad(it->second, load->indices, load->predicate, load->span); } return load; } @@ -88,7 +88,7 @@ class BufferSubstituter : public StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_map_.find(store->buffer.get()); if (it != buffer_map_.end()) { - return BufferStore(it->second, store->value, store->indices, store->span); + return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); } return store; } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 4554038bc7702..40df8b65c2952 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -254,7 +254,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); Buffer new_buffer = Subst(new_buffer_load->buffer.get()); if (!new_buffer.same_as(new_buffer_load->buffer)) { - return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->predicate, + new_buffer_load->span); } return std::move(new_buffer_load); } @@ -293,7 +294,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Buffer new_buffer = Subst(new_buffer_store->buffer.get()); if (!new_buffer.same_as(new_buffer_store->buffer)) { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, - new_buffer_store->span); + new_buffer_store->predicate, new_buffer_store->span); } return std::move(new_buffer_store); } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 0c0d47571c4a2..ac1cf0ef11bbe 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -718,7 +718,8 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->axis_separators, buffer->span}; old_to_new_read_buffers[buffer.as()] = new_buffer; - new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->predicate, + buffer_load->span)); break; } case 2: /* length */ { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 2cd2a698debef..b54be0796372a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -772,7 +772,7 @@ void BufferLoadNode::LegalizeDType() { } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -781,14 +781,15 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); node->LegalizeDType(); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, Span span) { - return BufferLoad(buffer, indices, span); + .set_body_typed([](Buffer buffer, Array indices, PrimExpr predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 089a1d31e7d0d..34b46583d5adf 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -127,7 +127,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { if (indices.same_as(op->indices)) { return GetRef(op); } else { - return BufferLoad(op->buffer, indices); + return BufferLoad(op->buffer, indices, op->predicate); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 4774471afcc0a..6bd4d97ce1c69 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,7 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -517,14 +518,14 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { - return BufferStore(buffer, value, indices, span); - }); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 5f7b9b4156c31..95f7519a5b6ff 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -257,7 +257,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not current supported in the inject rolling buffer pass."; + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; @@ -293,7 +295,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - return BufferLoad(op->buffer, indices, op->span); + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in inject rolling buffer pass."; + return BufferLoad(op->buffer, indices, op->predicate, op->span); } else { return expr; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 700587fe0e21e..3c2c6b67e653b 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -97,6 +97,8 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not currently supported in lower match buffer pass."; return Stmt(n); } } @@ -113,6 +115,8 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); } } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 619a9f0a9e8f0..885d5917136d1 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -67,6 +67,8 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer + ICHECK(!store->predicate.defined()) << "Predicated buffer store is not currently supported in " + "manifest shared memory local stage pass."; BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index bc606aa0b7ff0..3b418aac0cf57 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -213,7 +213,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // A write whose destination is known to already contain the // values to be written is a no-op. // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); - PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; + PrimExpr stores_existing_value = + store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); stores_existing_value = diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 05b636f11403d..e8d89bfb5700d 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -196,7 +196,7 @@ class AllocateConstRewrite : public StmtExprMutator { op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices); + return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173d..4ec4330a52b28 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -730,7 +730,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->span); + return BufferLoad(it->second, op->indices, op->predicate, op->span); } else { return expr; } @@ -743,7 +743,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->span); + return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); } else { return stmt; } @@ -938,8 +938,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return expr; } @@ -952,8 +955,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return stmt; } @@ -1418,7 +1424,9 @@ class StorageFlattener : public StmtExprMutator { auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1573,8 +1581,10 @@ class StorageFlattener : public StmtExprMutator { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 5537c8a409a0a..ba5157d8bfc73 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -333,6 +333,8 @@ class ComputeLegalizer : public StmtExprMutator { ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -404,6 +406,8 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } @@ -565,6 +569,8 @@ class StorageLegalizer : public StmtExprMutator { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -598,6 +604,8 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 3f5c070250448..e2990ffba7667 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -72,6 +72,126 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } +bool EnableBufferLevelPredication() { + transform::PassContext pass_ctx = transform::PassContext::Current(); + Optional enable_buffer_predication = + pass_ctx->GetConfig("tir.enable_buffer_level_predication"); + if (enable_buffer_predication.defined()) { + return enable_buffer_predication.value(); + } + + // Use buffer-level predication by default for AArch64 SVE targets + return arith::TargetHasSVE(); +} + +/*! + * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a + * predicate expression where possible. + * + * \note For now we start with a minimalized case targeting block-level predicates + * produced by the split schedule primitive, with the potential for predicating + * more complex terms in the future if needed. + * + * \example + * Before: + * for i_0 in T.serial(4): + * for i_1 in T.vectorized(4): + * if i_0 * 4 + i_1 < 14: + * B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + * + * After: + * for i_0 in T.serial(4): + * predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14) + * A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) + * B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate) + */ +class TryPredicateBufferAccesses : public StmtExprMutator { + public: + TryPredicateBufferAccesses() {} + + /*! + * \brief Run the pass to try to exact predicates. + * \param stmt - The statement containing buffer accesses (loads and stores) + * we want to attempt to predicate. + * \param condition - The conditional expression (block-level predicate) + * that we will try to remove. + * \return pair - Boolean value for success/failure, the rewritten + * stmt if successful. + */ + std::pair Run(Stmt stmt, PrimExpr condition) { + // Check that the condition provided is of the form a < b, for now. + if (!condition->IsInstance()) { + return {false, stmt}; + } + + LT lt = Downcast(condition); + + // Check the form of the vectorized condition, we're expecting + // Ramp(...) < Broadcast(...) + if (!lt->a->IsInstance() || !lt->b->IsInstance()) { + return {false, stmt}; + } + + base_ = Downcast(lt->a)->base; + limit_ = Downcast(lt->b)->value; + + // Now we can try to predicate + Stmt predicated_stmt = StmtExprMutator::operator()(std::move(stmt)); + if (num_accesses_analyzed_ > 0 && num_accesses_analyzed_ == num_accesses_rewritten_) { + return {true, predicated_stmt}; + } + return {false, stmt}; + } + + private: + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return TryPredicateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return TryPredicateBufferAccess(store); + } + + template + AccessNode TryPredicateBufferAccess(AccessNode node) { + num_accesses_analyzed_ += 1; + + // Do not try to predicate non-vectorized accesses + Array indices = node->indices; + if (!indices.size() || !indices[0]->IsInstance()) { + return node; + } + Ramp ramp = Downcast(node->indices[0]); + + // The vectorized access pattern must match the base of the predicate + if (!tvm::StructuralEqual()(ramp->base, base_)) { + return node; + } + + DataType buf_predicate_dtype = + DataType(DataType::kInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), + ramp->dtype.is_scalable_vector()); + Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_}); + + num_accesses_rewritten_ += 1; + auto writer = node.CopyOnWrite(); + writer->predicate = lane_mask; + return node; + } + + /*! \brief The variable base expr of the predicate. */ + PrimExpr base_; + /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's + * evaluated value. */ + PrimExpr limit_; + /*! \brief The number of buffer accesses in the stmt we will analyze. */ + size_t num_accesses_analyzed_ = 0; + /*! \brief The number of buffer accesses rewritten with predicates. */ + size_t num_accesses_rewritten_ = 0; +}; + // Rewrite vectorized allocation access // This is necessary for making each vector component containing its own workspace. // Originates from Halide's loop vectorizer @@ -555,14 +675,26 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); - if (condition.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); - } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } + + // Check if we can rewrite the condition with predicated buffers + if (EnableBufferLevelPredication() && condition.dtype().is_scalable_or_fixed_length_vector() && + !else_case.defined()) { + std::pair success_stmt_pair = + TryPredicateBufferAccesses().Run(then_case, condition); + bool can_remove_if_then_else = success_stmt_pair.first; + if (can_remove_if_then_else) { + return success_stmt_pair.second; + } + } + + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(GetRef(op)); + } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 452638beda0ab..c93dd2fed4fc5 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -700,5 +700,31 @@ def before(a: T.handle): assert "get.active.lane.mask" in ll +@pytest.mark.skipif( + llvm_version_major() < 11, + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM" +) +def test_predicated_scalable_buffer(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index d4fa17bf8fa40..65381a0eb9ee8 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -348,5 +348,99 @@ def test_v0_16_ramp_broadcast_lanes(): assert graph.value.lanes == 12 +def test_v0_17_load_store_predicate(): + json_graph_v0_16 = { + "root": 1, + "nodes": [ + {"type_key": ""}, + { + "type_key": "tir.BufferStore", + "attrs": { + "buffer": "2", + "indices": "19", + "predicate": "0", + "span": "0", + "value": "13", + }, + }, + { + "type_key": "tir.Buffer", + "attrs": { + "axis_separators": "11", + "buffer_type": "1", + "data": "3", + "data_alignment": "64", + "dtype": "float32", + "elem_offset": "12", + "name": "4", + "offset_factor": "1", + "shape": "8", + "span": "0", + "strides": "10", + }, + }, + { + "type_key": "tir.Var", + "attrs": {"dtype": "handle", "name": "4", "span": "0", "type_annotation": "5"}, + }, + {"type_key": "runtime.String"}, + {"type_key": "PointerType", "attrs": {"element_type": "6", "storage_scope": "7"}}, + {"type_key": "PrimType", "attrs": {"dtype": "float32"}}, + {"type_key": "runtime.String", "repr_str": "global"}, + {"type_key": "Array", "data": [9]}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "8"}}, + {"type_key": "Array"}, + {"type_key": "Array"}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + { + "type_key": "tir.BufferLoad", + "attrs": { + "buffer": "2", + "dtype": "float32x4", + "indices": "14", + "predicate": "0", + "span": "0", + }, + }, + {"type_key": "Array", "data": [15]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "16", + "dtype": "int32x4", + "lanes": "18", + "span": "0", + "stride": "17", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "Array", "data": [20]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "21", + "dtype": "int32x4", + "lanes": "23", + "span": "0", + "stride": "22", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + ], + "b64ndarrays": [], + "attrs": {"tvm_version": "0.16.0"}, + } + + expr = tvm.ir.load_json(json.dumps(json_graph_v0_16)) + buffer_store = expr + buffer_load = buffer_store.value + assert not buffer_store.predicate + assert not buffer_load.predicate + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index de5453eb5c44d..e96a546d6e30d 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -125,12 +125,15 @@ def main(A: T.Buffer((25,), "float32")): tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) -def test_vectorize_with_if(extent, target): +def test_vectorize_with_if(): + extent = 4 + target = simple_target + @I.ir_module class Before: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -141,7 +144,8 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): @I.ir_module class After: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -156,6 +160,43 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_if_scalable_extent(): + extent = T.vscale() * 4 + target = sve_target + + @I.ir_module + class Before: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + for i in T.vectorized(extent): + if x < n: + A[i] = A[i] + T.float32(1) + else: + if i < n: + A[i] = T.float32(2) + + @I.ir_module + class After: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + if x < n: + A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( + T.float32(1), extent + ) + else: + A.store( + T.Broadcast(T.float32(2), T.vscale() * 4), + [T.Ramp(0, 1, T.vscale() * 4)], + predicate=T.get_active_lane_mask("int1xvscalex4", 0, n), + ) + + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_with_if_cond_int64(): m = te.size_var("m", dtype="int64") A = te.placeholder((m,), name="A", dtype="float32") @@ -488,5 +529,170 @@ def main(A: T.Buffer((16,), "float32")): tvm.tir.transform.VectorizeLoop()(Mod) +def test_vectorize_and_predicate_all_buffer_loads_stores(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + load_a = T.meta_var( + A.load( + [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14) + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.store( + add_1, + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_some_buffer_loads_stores(): + # Currently revert to scalarizing the block if not all accesses + # have been predicated, otherwise incorrect code is generated. + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_multiple_access_statements(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + A[i_0 * 4 + i_1] = 2.0 + B[i_0 * 4 + i_1] = 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + A.store( + T.Broadcast(T.float32(2), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + B.store( + T.Broadcast(T.float32(1), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_invalid_conditions(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 > 14: + A[i_0 * 4 + i_1] = 2.0 + if 14 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + if i_0 * 4 + i_1 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + for i_1_s in range(4): + if i_0 * 4 + i_1_s > 14: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if 14 < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if i_0 * 4 + i_1_s < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_with_explicitly_disabled_buffer_level_predication(): + # Since the target is has the SVe feature, buffer level predication is enabled + # by default. However, it has been explicitely disabled by the pass context + # option, so no buffer-level predicates should be added. + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0 * 4 + i_1_s] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": False}): + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index c20784b4bf754..4636646b9216d 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -468,6 +468,20 @@ def test_ir_builder_tir_buffer_store_scalable_vec(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) +def test_ir_builder_tir_buffer_store_predicate(): + buffer_a = T.Buffer((30,), "float32") + value = T.broadcast(0.11, T.vscale() * 4) + index = T.ramp(0, 1, T.vscale() * 4) + predicate = T.broadcast(1, T.vscale() * 4) + + with IRBuilder() as ib: + T.buffer_store(buffer_a, value, [index], predicate) + + ir_actual = ib.get() + ir_expected = tir.BufferStore(buffer_a, value, [index], predicate) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + def test_ir_builder_tir_prefetch(): with IRBuilder() as ib: buffer_a = T.Buffer((128, 128), "float32") diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index edc6da31636bf..13e6aec285a67 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -948,5 +948,83 @@ def func(): _assert_print(func, expected_output) +def test_predicated_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4))) + A.store(a_load, [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(main, expected_output) + + +def test_predicated_buffer_load_store(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_map = { + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + } + buffer_load = tir.BufferLoad( + buffer=buffer_map[b], indices=[0, tir.Ramp(0, 4, 4)], predicate=tir.Broadcast(0, 4) + ) + body = tir.BufferStore( + buffer=buffer_map[a], + value=buffer_load, + indices=[0, tir.Ramp(0, 2, 4)], + predicate=tir.Broadcast(0, 4), + ) + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map=buffer_map, + body=body, + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(B.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(func, expected_output) + + +def test_predicated_scalable_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + mask = T.meta_var(T.get_active_lane_mask("int1xvscalex4", 0, 13)) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask)) + A.store(a_load, [0, T.Ramp(0, 2, T.vscale() * 4)], predicate=mask) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(\ +A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)), \ +[0, T.Ramp(0, 2, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 73bf200bb22a0..bcc318caf6f2c 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3352,6 +3352,18 @@ def func(a: T.handle): return func +def predicated_buffer_load_store(): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in range(4): + load_a = T.meta_var(A.load([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(1.0, 4))) + B.store(load_a, [T.Ramp(0, 2, 4)], predicate=T.Broadcast(1.0, 4)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4116,6 +4128,8 @@ def func(A: R.Object): buffer_axis_separator, buffer_ramp_access_as_slice_index, ramp_int64, + scalable_vectors, + predicated_buffer_load_store, let_expression, void_ptr, decl_buffer,