From c4812e9763102a7f3b3e9ac7f9a9226ad05254e8 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Wed, 1 May 2024 10:39:13 +0000
Subject: [PATCH 1/7] [SVE] Add support for representing and creating
 buffer-level predicates

Representation
--------------
This commit extends `BufferLoad` and `BufferStore` to accept a predicate
mask argument indicating which lanes in a vectorized buffer load/store
should be read/written. As a simple example, we can load all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
```
Or disable loading all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
```

In TVMScript, buffer loads and stores are currently displayed using a
"short-hand" notation, e.g. `A[0:4]`, but there was no clear path for
extending this notation to support predicates. Therefore, a "long-hand"
notation is introduced, e.g. `A.load([T.Ramp(0, 1, 4)], predicate=...)`.
The TVMScript printer falls back to the long-hand notation whenever
predicates are specified.

Creation
--------
Buffer-level predication becomes more compelling when combined with the
`tir.get_active_lane_mask` intrinsic, which can be used to mask off lanes
when the vectorized axis is not divisible by the vector length. A detailed
example and rationale can be found in the
[RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).

Predicated buffer loads/stores are created in the `VectorizeLoop` pass via
`TryPredicateBufferAccesses`. This pass aims to convert block-level
predicates, e.g.
```
for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
```
to buffer-level predicates, e.g.
```
for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)
```
It takes a conservative approach for now, focusing only on expressions
produced by the split scheduling primitive, but more complex expressions
could be supported in the future.

`TryPredicateBufferAccesses` can be explicitly enabled or disabled with the
`tir.enable_buffer_level_predication` pass context option. It is disabled
by default, unless the target supports SVE, in which case it is enabled.
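For example, the rewrite can be exercised directly through the vectorizer,
independent of the target default. The following is a minimal sketch based
on the tests added in this patch; the `before` function is just an
illustrative input, not part of the new API:
```
import tvm
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i_0 in T.serial(4):
        for i_1 in T.vectorized(4):
            if i_0 * 4 + i_1 < 14:
                B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0


mod = tvm.IRModule.from_expr(before)
# Force-enable buffer-level predication, regardless of whether the target
# supports SVE, then run the vectorizer.
with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}):
    mod = tvm.tir.transform.VectorizeLoop()(mod)
```
After the pass, the load of `A` and the store to `B` carry a
`T.get_active_lane_mask(...)` predicate in place of the block-level `if`,
as shown in the before/after example above.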
Co-authored-by: Elen Kalda Co-authored-by: Neil Hickey Change-Id: Idde259a7d7e4536f00ed3a1dafedd0a5d24a1593 --- include/tvm/script/ir_builder/tir/ir.h | 4 +- include/tvm/tir/expr.h | 7 +- include/tvm/tir/stmt.h | 6 +- python/tvm/ir/json_compact.py | 27 +++ python/tvm/script/ir_builder/tir/ir.py | 6 +- python/tvm/script/parser/tir/parser.py | 2 + python/tvm/tir/buffer.py | 51 +++++ python/tvm/tir/expr.py | 13 +- python/tvm/tir/stmt.py | 7 +- src/driver/driver_api.cc | 1 + src/script/ir_builder/tir/ir.cc | 5 +- src/script/printer/tir/buffer.cc | 23 +- src/target/llvm/codegen_llvm.cc | 50 ++-- src/target/llvm/codegen_llvm.h | 11 +- src/te/operation/create_primfunc.cc | 4 +- src/tir/analysis/device_constraint_utils.cc | 5 +- src/tir/contrib/ethosu/passes.cc | 3 +- src/tir/ir/expr.cc | 7 +- src/tir/ir/expr_functor.cc | 2 +- src/tir/ir/stmt.cc | 9 +- src/tir/transforms/inject_rolling_buffer.cc | 8 +- src/tir/transforms/lower_match_buffer.cc | 4 + .../manifest_shared_memory_local_stage.cc | 2 + src/tir/transforms/remove_no_op.cc | 3 +- .../remove_weight_layout_rewrite_block.cc | 2 +- src/tir/transforms/storage_flatten.cc | 22 +- .../transforms/unsupported_dtype_legalize.cc | 8 + src/tir/transforms/vectorize_loop.cc | 138 ++++++++++- .../codegen/test_target_codegen_aarch64.py | 26 +++ tests/python/relay/test_json_compact.py | 94 ++++++++ .../test_tir_transform_vectorize.py | 214 +++++++++++++++++- .../test_tvmscript_ir_builder_tir.py | 14 ++ .../tvmscript/test_tvmscript_printer_tir.py | 78 +++++++ .../tvmscript/test_tvmscript_roundtrip.py | 14 ++ 34 files changed, 810 insertions(+), 60 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5b44f79ad70a..9fb212307bfc 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -411,8 +411,10 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. + * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be + * stored. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate); /*! * \brief The prefetch hint for a buffer diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 39b32f563350..b3673c4bb356 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode { Buffer buffer; /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief The predicate mask for loading values. 
*/ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode { hash_reduce(dtype); hash_reduce(buffer); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferLoad"; @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, + PrimExpr predicate = PrimExpr(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 07cc9b5ad0d5..b60e7a80cfae 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode { PrimExpr value; /*! \brief The indices location to be stored. */ Array indices; + /*! \brief The predicate mask for storing values. */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); v->Visit("value", &value); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode { hash_reduce(buffer); hash_reduce(value); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferStore"; @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Span span = Span()); + PrimExpr predicate = PrimExpr(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index cb6e031667c5..756dbc4992f4 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -57,6 +57,31 @@ def _updater(data): return _updater +def create_updater_16_to_17(): + """ + Create an update to upgrade json from v0.16 to v0.17 + + Returns + ------- + fupdater : function + The updater function + """ + + def _update_predicate_argument(item, nodes): + null_value_idx = 0 + null_value = nodes[null_value_idx] + assert str(null_value) == "{'type_key': ''}", f"Expected a null value but got {null_value}" + item["attrs"]["predicate"] = str(null_value_idx) + return item + + node_map = { + "tir.BufferLoad": _update_predicate_argument, + "tir.BufferStore": _update_predicate_argument, + } + + return create_updater(node_map, "0.16", "0.17") + + def create_updater_15_to_16(): """ Create an update to upgrade json from v0.15 to v0.16 @@ -316,5 +341,7 @@ def _from_version(data): data = create_updater({}, "0.14", "0.15")(data) if _from_version(data).startswith("0.15"): data = create_updater_15_to_16()(data) + if _from_version(data).startswith("0.16"): + data = create_updater_16_to_17()(data) return json.dumps(data, indent=2) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5a0a564a2ab5..1550ebc49efa 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1265,6 +1265,7 @@ def buffer_store( buffer: Buffer, # pylint: 
disable=redefined-outer-name value: PrimExpr, indices: List[Union[PrimExpr, slice]], + predicate: Optional[PrimExpr] = None, ) -> None: """Buffer store node. @@ -1278,6 +1279,9 @@ def buffer_store( indices : List[Union[PrimExpr, slice]] The indices location to be stored. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel @@ -1298,7 +1302,7 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices + buffer, value, expr_indices, predicate ) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 679ae4e8adc0..600099bb0afb 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, str): # Ignore docstrings pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index ec57ad7801ca..b6de8791dea1 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -141,6 +141,57 @@ def vstore(self, begin, value): begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin return _ffi_api.BufferVStore(self, begin, value) # type: ignore + def load(self, indices, predicate=None): + """ + Load values at specified indices from buffer. + + Longhand notation that can be used for complex buffer load + expressions. For example, when the load involves predication. + + Parameters + ---------- + indices : List[PrimExpr] + The buffer indices to load values from. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be loaded. + + Returns + ------- + BufferLoad + A buffer load Expr. + """ + from .expr import BufferLoad # pylint: disable=import-outside-toplevel + + return BufferLoad(self, indices, predicate) + + def store(self, value, indices, predicate=None): + """ + Store given value at the specified indices in the buffer. + + Longhand notation that can be used for complex buffer store + statements. For example, when the store involves predication. + + Parameters + ---------- + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The buffer indices to store values to. + + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. + + Returns + ------- + BufferStore + A buffer store Stmt. + """ + from .stmt import BufferStore # pylint: disable=import-outside-toplevel + + return BufferStore(self, value, indices, predicate) + def scope(self): """Return the storage scope associated with this buffer. Returns diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fca501874d94..b9ea2c414d26 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1093,20 +1093,27 @@ class BufferLoad(PrimExprWithOp): The buffer to be loaded. indices : List[PrimExpr] - The buffer indices. + The buffer indices to load values from. span : Optional[Span] The location of this expression in the source code. 
+ + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be loaded. """ buffer: Buffer indices: List[PrimExpr] def __init__( - self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + self, + buffer: Buffer, + indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, + span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferLoad, buffer, indices, span # type: ignore + _ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 992c388e27bb..6f8ce42cd938 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -224,6 +224,9 @@ class BufferStore(Stmt): indices : List[PrimExpr] The indices location to be stored. + predicate : Optional[PrimExpr] + A vector mask of int1 values indicating which lanes of a vector are to be stored. + span : Optional[Span] The location of the stmt in the source code. """ @@ -231,6 +234,7 @@ class BufferStore(Stmt): buffer: Buffer value: PrimExpr indices: List[PrimExpr] + predicate: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -238,10 +242,11 @@ def __init__( buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + _ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc..3026f6e58f18 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 3ce5c15e6cd0..121d531fc9e0 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices) { +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + PrimExpr predicate = PrimExpr()) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? 
false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices) { } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices)); + AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } void Prefetch(Buffer buffer, Array bounds) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 45a0dfd2aea4..078d34fbba7b 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -273,14 +273,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); - return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + ExprDoc value = d->AsDoc(store->value, p->Attr("value")); + + // Use .store(...) syntax when there is a predicate + if (store->predicate.defined()) { + ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); + return ExprStmtDoc( + buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate})); + } + + return AssignDoc( + /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], + /*rhs=*/value, NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + + // Use .load(...) syntax when there is a predicate + if (load->predicate.defined()) { + ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); + return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate}); + } + return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6fc083d17ccf..562f7a8747b3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1668,9 +1668,9 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction) { DataType buffer_element_dtype = buffer->dtype; @@ -1750,6 +1750,11 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* predicate_value = nullptr; + if (predicate.defined()) { + predicate_value = MakeValue(predicate); + } + TypedPointer buffer_ptr = value_dtype.is_scalable_vector() ? 
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, @@ -1758,7 +1763,8 @@ void CodeGenLLVM::BufferAccessHelper( : CreateBufferPtr( MakeValue(buffer->data), buffer_element_dtype, all_index_values, value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); - auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + auto instruction = + make_instruction(buffer_ptr, subelement_i, predicate_value, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } } @@ -1768,11 +1774,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { std::vector loads; - auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, - bool is_volatile) { + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, + llvm::Value* predicate, int alignment, bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + llvm::Instruction* load = nullptr; + if (predicate != NULL) { + load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + predicate); + } else { + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #elif TVM_LLVM_VERSION >= 80 auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); @@ -1787,7 +1799,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load); if (loads.size() == 1) { return loads[0]; @@ -1902,24 +1914,32 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* value = MakeValue(op->value); - auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, - bool is_volatile) { + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate, + int alignment, bool is_volatile) { llvm::Value* to_store = value; + llvm::Instruction* store; + if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), - is_volatile); + if (predicate != NULL) { + store = + builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); + } else { + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #else - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif + return store; }; // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. 
- BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 06b36cb183d3..832ed2fbfeb8 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -330,6 +330,9 @@ class CodeGenLLVM : public ExprFunctor, * * \param indices The indices at which the buffer is being accessed. * + * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be + * stored. + * * \param value_dtype The datatype to be read from (BufferLoad) or * written to (BufferStore) the buffer. * @@ -342,6 +345,8 @@ class CodeGenLLVM : public ExprFunctor, * stored/loaded. If -1, indicates that the entire type, * vector or scalar, should be written. * + * - predicate: The predicate mask of the buffer. + * * - alignment: The alignment to be used for the read/write. * * - is_volatile: Whether the read/write should be volatile. @@ -349,9 +354,9 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction); // Initialize target virtual void InitTarget(); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 03de68e32624..c7dbf3f5e042 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -79,7 +79,7 @@ class BufferSubstituter : public StmtExprMutator { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_map_.find(load->buffer.get()); if (it != buffer_map_.end()) { - return BufferLoad(it->second, load->indices, load->span); + return BufferLoad(it->second, load->indices, load->predicate, load->span); } return load; } @@ -88,7 +88,7 @@ class BufferSubstituter : public StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_map_.find(store->buffer.get()); if (it != buffer_map_.end()) { - return BufferStore(it->second, store->value, store->indices, store->span); + return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); } return store; } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 4554038bc770..40df8b65c295 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -254,7 +254,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); Buffer new_buffer = Subst(new_buffer_load->buffer.get()); if (!new_buffer.same_as(new_buffer_load->buffer)) { - return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->predicate, + new_buffer_load->span); } return std::move(new_buffer_load); } @@ -293,7 +294,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Buffer new_buffer = Subst(new_buffer_store->buffer.get()); if (!new_buffer.same_as(new_buffer_store->buffer)) { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, - new_buffer_store->span); + new_buffer_store->predicate, new_buffer_store->span); } return std::move(new_buffer_store); } diff --git 
a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 0c0d47571c4a..ac1cf0ef11bb 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -718,7 +718,8 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->axis_separators, buffer->span}; old_to_new_read_buffers[buffer.as()] = new_buffer; - new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->predicate, + buffer_load->span)); break; } case 2: /* length */ { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 2cd2a698debe..b54be0796372 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -772,7 +772,7 @@ void BufferLoadNode::LegalizeDType() { } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -781,14 +781,15 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); node->LegalizeDType(); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, Span span) { - return BufferLoad(buffer, indices, span); + .set_body_typed([](Buffer buffer, Array indices, PrimExpr predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 089a1d31e7d0..34b46583d5ad 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -127,7 +127,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { if (indices.same_as(op->indices)) { return GetRef(op); } else { - return BufferLoad(op->buffer, indices); + return BufferLoad(op->buffer, indices, op->predicate); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 4774471afcc0..6bd4d97ce1c6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,7 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -517,14 +518,14 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { - return BufferStore(buffer, value, indices, span); - }); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git 
a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 5f7b9b4156c3..95f7519a5b6f 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -257,7 +257,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not current supported in the inject rolling buffer pass."; + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; @@ -293,7 +295,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - return BufferLoad(op->buffer, indices, op->span); + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in inject rolling buffer pass."; + return BufferLoad(op->buffer, indices, op->predicate, op->span); } else { return expr; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 700587fe0e21..3c2c6b67e653 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -97,6 +97,8 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not currently supported in lower match buffer pass."; return Stmt(n); } } @@ -113,6 +115,8 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); } } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 619a9f0a9e8f..885d5917136d 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -67,6 +67,8 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer + ICHECK(!store->predicate.defined()) << "Predicated buffer store is not currently supported in " + "manifest shared memory local stage pass."; BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index bc606aa0b7ff..3b418aac0cf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -213,7 +213,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // A write whose destination is known to already contain the // values to be written is a no-op. 
// PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); - PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; + PrimExpr stores_existing_value = + store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); stores_existing_value = diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 05b636f11403..e8d89bfb5700 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -196,7 +196,7 @@ class AllocateConstRewrite : public StmtExprMutator { op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices); + return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index c51dfd7913e4..06554f5f1dd1 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -730,7 +730,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->span); + return BufferLoad(it->second, op->indices, op->predicate, op->span); } else { return expr; } @@ -743,7 +743,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->span); + return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); } else { return stmt; } @@ -938,8 +938,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return expr; } @@ -952,8 +955,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return stmt; } @@ -1418,7 +1424,9 @@ class StorageFlattener : public StmtExprMutator { auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1573,8 
+1581,10 @@ class StorageFlattener : public StmtExprMutator { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 5a14beb6dc4c..c75ecf77e708 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -330,6 +330,8 @@ class ComputeLegalizer : public StmtExprMutator { ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -401,6 +403,8 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } @@ -562,6 +566,8 @@ class StorageLegalizer : public StmtExprMutator { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -595,6 +601,8 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index c4dde01b8f81..182369ba03fa 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -72,6 +72,126 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } +bool EnableBufferLevelPredication() { + transform::PassContext pass_ctx = transform::PassContext::Current(); + Optional enable_buffer_predication = + pass_ctx->GetConfig("tir.enable_buffer_level_predication"); + if (enable_buffer_predication.defined()) { + return enable_buffer_predication.value(); + } + + // Use buffer-level predication by default for AArch64 SVE targets + return arith::TargetHasSVE(); +} + +/*! + * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a + * predicate expression where possible. + * + * \note For now we start with a minimalized case targeting block-level predicates + * produced by the split schedule primitive, with the potential for predicating + * more complex terms in the future if needed. 
+ * + * \example + * Before: + * for i_0 in T.serial(4): + * for i_1 in T.vectorized(4): + * if i_0 * 4 + i_1 < 14: + * B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + * + * After: + * for i_0 in T.serial(4): + * predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14) + * A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) + * B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate) + */ +class TryPredicateBufferAccesses : public StmtExprMutator { + public: + TryPredicateBufferAccesses() {} + + /*! + * \brief Run the pass to try to exact predicates. + * \param stmt - The statement containing buffer accesses (loads and stores) + * we want to attempt to predicate. + * \param condition - The conditional expression (block-level predicate) + * that we will try to remove. + * \return pair - Boolean value for success/failure, the rewritten + * stmt if successful. + */ + std::pair Run(Stmt stmt, PrimExpr condition) { + // Check that the condition provided is of the form a < b, for now. + if (!condition->IsInstance()) { + return {false, stmt}; + } + + LT lt = Downcast(condition); + + // Check the form of the vectorized condition, we're expecting + // Ramp(...) < Broadcast(...) + if (!lt->a->IsInstance() || !lt->b->IsInstance()) { + return {false, stmt}; + } + + base_ = Downcast(lt->a)->base; + limit_ = Downcast(lt->b)->value; + + // Now we can try to predicate + Stmt predicated_stmt = StmtExprMutator::operator()(std::move(stmt)); + if (num_accesses_analyzed_ > 0 && num_accesses_analyzed_ == num_accesses_rewritten_) { + return {true, predicated_stmt}; + } + return {false, stmt}; + } + + private: + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return TryPredicateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return TryPredicateBufferAccess(store); + } + + template + AccessNode TryPredicateBufferAccess(AccessNode node) { + num_accesses_analyzed_ += 1; + + // Do not try to predicate non-vectorized accesses + Array indices = node->indices; + if (!indices.size() || !indices[0]->IsInstance()) { + return node; + } + Ramp ramp = Downcast(node->indices[0]); + + // The vectorized access pattern must match the base of the predicate + if (!tvm::StructuralEqual()(ramp->base, base_)) { + return node; + } + + DataType buf_predicate_dtype = + DataType(DataType::kInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), + ramp->dtype.is_scalable_vector()); + Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_}); + + num_accesses_rewritten_ += 1; + auto writer = node.CopyOnWrite(); + writer->predicate = lane_mask; + return node; + } + + /*! \brief The variable base expr of the predicate. */ + PrimExpr base_; + /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's + * evaluated value. */ + PrimExpr limit_; + /*! \brief The number of buffer accesses in the stmt we will analyze. */ + size_t num_accesses_analyzed_ = 0; + /*! \brief The number of buffer accesses rewritten with predicates. */ + size_t num_accesses_rewritten_ = 0; +}; + // Rewrite vectorized allocation access // This is necessary for making each vector component containing its own workspace. 
// Originates from Halide's loop vectorizer @@ -555,14 +675,26 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); - if (condition.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); - } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } + + // Check if we can rewrite the condition with predicated buffers + if (EnableBufferLevelPredication() && condition.dtype().is_scalable_or_fixed_length_vector() && + !else_case.defined()) { + std::pair success_stmt_pair = + TryPredicateBufferAccesses().Run(then_case, condition); + bool can_remove_if_then_else = success_stmt_pair.first; + if (can_remove_if_then_else) { + return success_stmt_pair.second; + } + } + + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(GetRef(op)); + } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index f73d96e7c916..41449aa233d4 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -780,5 +780,31 @@ def before(a: T.handle): assert "get.active.lane.mask" in ll +@pytest.mark.skipif( + llvm_version_major() < 11, + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM" +) +def test_predicated_scalable_buffer(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index d4fa17bf8fa4..65381a0eb9ee 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -348,5 +348,99 @@ def test_v0_16_ramp_broadcast_lanes(): assert graph.value.lanes == 12 +def test_v0_17_load_store_predicate(): + json_graph_v0_16 = { + "root": 1, + "nodes": [ + {"type_key": ""}, + { + "type_key": "tir.BufferStore", + "attrs": { + "buffer": "2", + "indices": "19", + "predicate": "0", + "span": "0", + "value": "13", + }, + }, + { + "type_key": "tir.Buffer", + "attrs": { + "axis_separators": "11", + "buffer_type": "1", + "data": "3", + "data_alignment": "64", + "dtype": "float32", + "elem_offset": "12", + "name": "4", + "offset_factor": "1", + "shape": "8", + "span": "0", + "strides": "10", + }, + }, + { + "type_key": "tir.Var", + "attrs": {"dtype": "handle", "name": "4", "span": "0", "type_annotation": "5"}, + }, + {"type_key": "runtime.String"}, + {"type_key": "PointerType", "attrs": {"element_type": "6", "storage_scope": "7"}}, + {"type_key": "PrimType", "attrs": {"dtype": "float32"}}, + {"type_key": 
"runtime.String", "repr_str": "global"}, + {"type_key": "Array", "data": [9]}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "8"}}, + {"type_key": "Array"}, + {"type_key": "Array"}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + { + "type_key": "tir.BufferLoad", + "attrs": { + "buffer": "2", + "dtype": "float32x4", + "indices": "14", + "predicate": "0", + "span": "0", + }, + }, + {"type_key": "Array", "data": [15]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "16", + "dtype": "int32x4", + "lanes": "18", + "span": "0", + "stride": "17", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "Array", "data": [20]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "21", + "dtype": "int32x4", + "lanes": "23", + "span": "0", + "stride": "22", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + ], + "b64ndarrays": [], + "attrs": {"tvm_version": "0.16.0"}, + } + + expr = tvm.ir.load_json(json.dumps(json_graph_v0_16)) + buffer_store = expr + buffer_load = buffer_store.value + assert not buffer_store.predicate + assert not buffer_load.predicate + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index de5453eb5c44..e96a546d6e30 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -125,12 +125,15 @@ def main(A: T.Buffer((25,), "float32")): tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) -def test_vectorize_with_if(extent, target): +def test_vectorize_with_if(): + extent = 4 + target = simple_target + @I.ir_module class Before: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -141,7 +144,8 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): @I.ir_module class After: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -156,6 +160,43 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_if_scalable_extent(): + extent = T.vscale() * 4 + target = sve_target + + @I.ir_module + class Before: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + for i in T.vectorized(extent): + if x < n: + A[i] = A[i] + T.float32(1) + else: + if i < n: + A[i] = T.float32(2) + + @I.ir_module + class After: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + if x < n: + A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + 
T.Broadcast( + T.float32(1), extent + ) + else: + A.store( + T.Broadcast(T.float32(2), T.vscale() * 4), + [T.Ramp(0, 1, T.vscale() * 4)], + predicate=T.get_active_lane_mask("int1xvscalex4", 0, n), + ) + + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_with_if_cond_int64(): m = te.size_var("m", dtype="int64") A = te.placeholder((m,), name="A", dtype="float32") @@ -488,5 +529,170 @@ def main(A: T.Buffer((16,), "float32")): tvm.tir.transform.VectorizeLoop()(Mod) +def test_vectorize_and_predicate_all_buffer_loads_stores(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + load_a = T.meta_var( + A.load( + [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14) + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.store( + add_1, + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_some_buffer_loads_stores(): + # Currently revert to scalarizing the block if not all accesses + # have been predicated, otherwise incorrect code is generated. 
+ @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_multiple_access_statements(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + A[i_0 * 4 + i_1] = 2.0 + B[i_0 * 4 + i_1] = 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + A.store( + T.Broadcast(T.float32(2), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + B.store( + T.Broadcast(T.float32(1), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_invalid_conditions(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 > 14: + A[i_0 * 4 + i_1] = 2.0 + if 14 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + if i_0 * 4 + i_1 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + for i_1_s in range(4): + if i_0 * 4 + i_1_s > 14: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if 14 < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if i_0 * 4 + i_1_s < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_with_explicitly_disabled_buffer_level_predication(): + # Since the target is has the SVe feature, buffer level predication is enabled + # by default. 
However, it has been explicitely disabled by the pass context + # option, so no buffer-level predicates should be added. + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0 * 4 + i_1_s] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": False}): + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index c20784b4bf75..4636646b9216 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -468,6 +468,20 @@ def test_ir_builder_tir_buffer_store_scalable_vec(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) +def test_ir_builder_tir_buffer_store_predicate(): + buffer_a = T.Buffer((30,), "float32") + value = T.broadcast(0.11, T.vscale() * 4) + index = T.ramp(0, 1, T.vscale() * 4) + predicate = T.broadcast(1, T.vscale() * 4) + + with IRBuilder() as ib: + T.buffer_store(buffer_a, value, [index], predicate) + + ir_actual = ib.get() + ir_expected = tir.BufferStore(buffer_a, value, [index], predicate) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + def test_ir_builder_tir_prefetch(): with IRBuilder() as ib: buffer_a = T.Buffer((128, 128), "float32") diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index edc6da31636b..13e6aec285a6 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -948,5 +948,83 @@ def func(): _assert_print(func, expected_output) +def test_predicated_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4))) + A.store(a_load, [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(main, expected_output) + + +def test_predicated_buffer_load_store(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_map = { + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + } + buffer_load = tir.BufferLoad( + buffer=buffer_map[b], 
indices=[0, tir.Ramp(0, 4, 4)], predicate=tir.Broadcast(0, 4) + ) + body = tir.BufferStore( + buffer=buffer_map[a], + value=buffer_load, + indices=[0, tir.Ramp(0, 2, 4)], + predicate=tir.Broadcast(0, 4), + ) + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map=buffer_map, + body=body, + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(B.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(func, expected_output) + + +def test_predicated_scalable_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + mask = T.meta_var(T.get_active_lane_mask("int1xvscalex4", 0, 13)) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask)) + A.store(a_load, [0, T.Ramp(0, 2, T.vscale() * 4)], predicate=mask) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(\ +A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)), \ +[0, T.Ramp(0, 2, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 73bf200bb22a..bcc318caf6f2 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3352,6 +3352,18 @@ def func(a: T.handle): return func +def predicated_buffer_load_store(): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in range(4): + load_a = T.meta_var(A.load([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(1.0, 4))) + B.store(load_a, [T.Ramp(0, 2, 4)], predicate=T.Broadcast(1.0, 4)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4116,6 +4128,8 @@ def func(A: R.Object): buffer_axis_separator, buffer_ramp_access_as_slice_index, ramp_int64, + scalable_vectors, + predicated_buffer_load_store, let_expression, void_ptr, decl_buffer, From f274dc412f60c8a30e73b12acf273338e6cc3068 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 3 May 2024 14:40:58 +0000 Subject: [PATCH 2/7] Fix lint and correct test config option name Change-Id: I864475c3d03e9b426ce5ef987989216d57f3e019 --- tests/python/codegen/test_target_codegen_aarch64.py | 2 +- tests/python/tir-transform/test_tir_transform_vectorize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 41449aa233d4..75e31365f432 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -782,7 +782,7 @@ def before(a: T.handle): @pytest.mark.skipif( llvm_version_major() < 11, - reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM" + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", ) def 
test_predicated_scalable_buffer(): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index e96a546d6e30..44b886eedd6a 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -622,7 +622,7 @@ def expected(a: T.handle, b: T.handle): ) before_mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] tvm.ir.assert_structural_equal(after, expected) From 3cdbd1229dc0ba48ce3e79c32ec9ffa1b8d50e7a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 14 May 2024 12:31:29 +0000 Subject: [PATCH 3/7] Address review comments This includes: * Taking into account possibility of target being overridden in the vectorize pass. * Predicate PrimExpr -> Optional * Checking that predicate is not used for any target that doesn't support it. * Use vload/vstore API as opposed to load/store * int1 mask -> uint1 mask for boolean representation. This is converted to int1 in the LLVM backend. Change-Id: I4da0705352e321f6be6333a5bb777caa6a6ca9ef --- include/tvm/script/ir_builder/tir/ir.h | 7 +- include/tvm/tir/buffer.h | 10 +- include/tvm/tir/expr.h | 4 +- include/tvm/tir/stmt.h | 4 +- python/tvm/script/ir_builder/tir/ir.py | 4 +- python/tvm/tir/buffer.py | 69 +++--------- python/tvm/tir/expr.py | 4 +- python/tvm/tir/stmt.py | 4 +- src/arith/analyzer.cc | 5 +- src/arith/const_int_bound.cc | 2 +- src/arith/scalable_expression.cc | 3 +- src/arith/scalable_expression.h | 4 +- src/script/ir_builder/tir/ir.cc | 2 +- src/script/printer/tir/buffer.cc | 4 +- src/target/llvm/codegen_llvm.cc | 34 ++++-- src/target/llvm/codegen_llvm.h | 7 +- src/target/source/codegen_c.cc | 2 + src/target/source/codegen_webgpu.cc | 3 + src/tir/ir/buffer.cc | 31 +++--- src/tir/ir/expr.cc | 15 ++- src/tir/ir/stmt.cc | 43 +++++--- src/tir/transforms/vectorize_loop.cc | 52 ++++++--- tests/python/codegen/test_target_codegen.py | 56 ++++++++++ .../codegen/test_target_codegen_aarch64.py | 2 +- .../codegen/test_target_codegen_llvm.py | 29 +++++ tests/python/tir-base/test_tir_nodes.py | 46 ++++++++ .../test_tir_transform_vectorize.py | 101 +++++++++++++++--- .../tvmscript/test_tvmscript_printer_tir.py | 43 +++++--- .../tvmscript/test_tvmscript_roundtrip.py | 6 +- 29 files changed, 431 insertions(+), 165 deletions(-) create mode 100644 tests/python/codegen/test_target_codegen.py diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 9fb212307bfc..380c2fcce25d 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -411,10 +411,11 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. - * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be - * stored. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * stored. The number lanes of the mask must be equal to the number of lanes in value. 
*/ -void BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + Optional predicate); /*! * \brief The prefetch hint for a buffer diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index b2736a30e4bb..8719476af98f 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -209,14 +209,20 @@ class Buffer : public ObjectRef { * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index * \param dtype The data type to be loaded. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * stored. The number lanes of the mask must be equal to the number of lanes in value. */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; + TVM_DLL PrimExpr vload(Array begin, DataType dtype, + Optional predicate = NullOpt) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index * \param value The value to be stored. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * stored. The number lanes of the mask must be equal to the number of lanes in value. */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + TVM_DLL Stmt vstore(Array begin, PrimExpr value, + Optional predicate = NullOpt) const; /*! * \brief Get a flattened version of the buffer diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index b3673c4bb356..d9b65dc8745c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -631,7 +631,7 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The indices location to be loaded. */ Array indices; /*! \brief The predicate mask for loading values. */ - PrimExpr predicate; + Optional predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); @@ -680,7 +680,7 @@ class BufferLoadNode : public PrimExprNode { class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, - PrimExpr predicate = PrimExpr(), Span span = Span()); + Optional predicate = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b60e7a80cfae..c77254ed34cb 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -232,7 +232,7 @@ class BufferStoreNode : public StmtNode { /*! \brief The indices location to be stored. */ Array indices; /*! \brief The predicate mask for storing values. */ - PrimExpr predicate; + Optional predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); @@ -265,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - PrimExpr predicate = PrimExpr(), Span span = Span()); + Optional predicate = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 1550ebc49efa..8289ea96ae25 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1281,7 +1281,9 @@ def buffer_store( The indices location to be stored. 
predicate : Optional[PrimExpr] - A vector mask of int1 values indicating which lanes of a vector are to be stored. + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index b6de8791dea1..61a11e0330a6 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -101,7 +101,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore ) - def vload(self, begin, dtype=None): + def vload(self, begin, dtype=None, predicate=None): """Generate an Expr that loads dtype from begin index. Parameters @@ -113,6 +113,11 @@ def vload(self, begin, dtype=None): The data type to be loaded, can be vector type which have lanes that is multiple of Buffer.dtype + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. + Returns ------- load : Expr @@ -120,9 +125,9 @@ def vload(self, begin, dtype=None): """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin dtype = dtype if dtype else self.dtype - return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore + return _ffi_api.BufferVLoad(self, begin, dtype, predicate) # type: ignore - def vstore(self, begin, value): + def vstore(self, begin, value, predicate=None): """Generate a Stmt that store value into begin index. Parameters @@ -133,64 +138,18 @@ def vstore(self, begin, value): value : Expr The value to be stored. + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. + Returns ------- store : Stmt The corresponding store stmt. """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin - return _ffi_api.BufferVStore(self, begin, value) # type: ignore - - def load(self, indices, predicate=None): - """ - Load values at specified indices from buffer. - - Longhand notation that can be used for complex buffer load - expressions. For example, when the load involves predication. - - Parameters - ---------- - indices : List[PrimExpr] - The buffer indices to load values from. - - predicate : Optional[PrimExpr] - A vector mask of int1 values indicating which lanes of a vector are to be loaded. - - Returns - ------- - BufferLoad - A buffer load Expr. - """ - from .expr import BufferLoad # pylint: disable=import-outside-toplevel - - return BufferLoad(self, indices, predicate) - - def store(self, value, indices, predicate=None): - """ - Store given value at the specified indices in the buffer. - - Longhand notation that can be used for complex buffer store - statements. For example, when the store involves predication. - - Parameters - ---------- - value : PrimExpr - The value to be stored. - - indices : List[PrimExpr] - The buffer indices to store values to. - - predicate : Optional[PrimExpr] - A vector mask of int1 values indicating which lanes of a vector are to be stored. - - Returns - ------- - BufferStore - A buffer store Stmt. 
- """ - from .stmt import BufferStore # pylint: disable=import-outside-toplevel - - return BufferStore(self, value, indices, predicate) + return _ffi_api.BufferVStore(self, begin, value, predicate) # type: ignore def scope(self): """Return the storage scope associated with this buffer. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index b9ea2c414d26..baed599add5d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1099,7 +1099,9 @@ class BufferLoad(PrimExprWithOp): The location of this expression in the source code. predicate : Optional[PrimExpr] - A vector mask of int1 values indicating which lanes of a vector are to be loaded. + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. """ buffer: Buffer diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 6f8ce42cd938..aa3b17a7a12f 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -225,7 +225,9 @@ class BufferStore(Stmt): The indices location to be stored. predicate : Optional[PrimExpr] - A vector mask of int1 values indicating which lanes of a vector are to be stored. + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. span : Optional[Span] The location of the stmt in the source code. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 0c4248bd3f26..08d5e9379dc6 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -233,15 +233,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // "T.vscale" and the compile target uses a scalable architecture extension like // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. + Target curr_target = Target::Current(); if (ContainsVscaleCall(simplified)) { - if (TargetHasSVE()) { + if (TargetHasSVE(curr_target)) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. 
This proof currently only supports " "AArch64 SVE targets, but the target was " - << Target::Current(); + << curr_target; } return false; } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 2f9d640ee712..ecd3b25bfc67 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -370,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) { + } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) { unsigned int max_val = *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end()); return MakeBound(1, max_val); diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 2df035d6151a..e5f3bc28ba52 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -93,8 +93,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasSVE() { - Target current_target = Target::Current(); +bool TargetHasSVE(Target current_target) { bool has_sve{false}; if (current_target.defined()) { has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 8e807eb3b839..06ff8104e928 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -79,9 +80,10 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr /*! * \brief Check whether the compilation target supports SVE + * \param target The target to check. * \return Whether SVE is supported */ -bool TargetHasSVE(); +bool TargetHasSVE(Target target); } // namespace arith } // namespace tvm diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 121d531fc9e0..17353561ee54 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -525,7 +525,7 @@ Var EnvThread(String thread_tag, DataType dtype) { } void BufferStore(Buffer buffer, PrimExpr value, Array indices, - PrimExpr predicate = PrimExpr()) { + Optional predicate = NullOpt) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? 
false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 078d34fbba7b..c2d504c2fd7e 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -280,7 +280,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); return ExprStmtDoc( - buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate})); + buffer->Attr("vstore")->Call({indices, value}, {"predicate"}, {predicate})); } return AssignDoc( @@ -297,7 +297,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (load->predicate.defined()) { ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); - return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate}); + return buffer->Attr("vload")->Call({indices}, {"predicate"}, {predicate}); } return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 562f7a8747b3..6098a3f32f0d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1668,7 +1668,7 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + Buffer buffer, Array indices, Optional predicate, DataType value_dtype, std::function make_instruction) { @@ -1752,7 +1752,7 @@ void CodeGenLLVM::BufferAccessHelper( llvm::Value* predicate_value = nullptr; if (predicate.defined()) { - predicate_value = MakeValue(predicate); + predicate_value = MakeValue(predicate.value()); } TypedPointer buffer_ptr = @@ -1776,21 +1776,28 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, llvm::Value* predicate, int alignment, bool is_volatile) { -#if TVM_LLVM_VERSION >= 110 llvm::Instruction* load = nullptr; if (predicate != NULL) { + ICHECK(!is_volatile) + << "The masked load intrinsic does not support declaring load as volatile."; +#if TVM_LLVM_VERSION >= 130 load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), predicate); +#elif TVM_LLVM_VERSION >= 110 + load = builder_->CreateMaskedLoad(buffer_ptr.addr, llvm::Align(alignment), predicate); +#else + load = builder_->CreateMaskedLoad(buffer_ptr.addr, alignment, predicate); +#endif } else { +#if TVM_LLVM_VERSION >= 110 load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); - } #elif TVM_LLVM_VERSION >= 80 - auto load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); + load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif + } loads.push_back(load); return load; @@ -1922,17 +1929,24 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } -#if TVM_LLVM_VERSION >= 110 + if (predicate != NULL) { + ICHECK(!is_volatile) + << "The masked store intrinsic does 
not support declaring store as volatile."; +#if TVM_LLVM_VERSION >= 110 store = builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); +#else + store = builder_->CreateMaskedStore(to_store, buffer_ptr.addr, alignment, predicate); +#endif } else { +#if TVM_LLVM_VERSION >= 110 store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), is_volatile); - } #else - store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif + } return store; }; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 832ed2fbfeb8..9b28e892ef9b 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -330,8 +330,9 @@ class CodeGenLLVM : public ExprFunctor, * * \param indices The indices at which the buffer is being accessed. * - * \param predicate A vector mask of int1 values indicating which lanes of a vector are to be - * stored. + * \param predicate A vector mask of boolean values indicating which lanes of a + * vector are to be stored. The number lanes of the mask must be equal to the + * number of lanes in value. * * \param value_dtype The datatype to be read from (BufferLoad) or * written to (BufferStore) the buffer. @@ -354,7 +355,7 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + Buffer buffer, Array indices, Optional predicate, DataType value_dtype, std::function make_instruction); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 009fc1672ace..5f6f493e08a3 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -764,6 +764,7 @@ void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -823,6 +824,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI void CodeGenC::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index ba925056a379..f62e0db7ffdf 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -459,6 +459,7 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // // to ensure correctness in the case of nested-expression // do not try to lift common printings from each case ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -531,6 +532,8 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + 
ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; + DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; PrimExpr index = op->indices[0]; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index d71187922874..025605333138 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -399,37 +399,44 @@ Buffer Buffer::GetFlattenedBuffer() const { } } -PrimExpr Buffer::vload(Array begin, DataType value_dtype) const { +PrimExpr Buffer::vload(Array begin, DataType value_dtype, + Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.lanes() % n->dtype.lanes() == 0) + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot load " << value_dtype << " from buffer of " << n->dtype; Array indices = begin; - int factor = value_dtype.lanes() / n->dtype.lanes(); - if (factor > 1) { - indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); + PrimExpr base = indices[indices.size() - 1]; + if (value_dtype.is_fixed_length_vector()) { + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1 && base.dtype().is_scalar()) { + indices.Set(indices.size() - 1, Ramp(base, 1, factor)); + } } - return BufferLoad(*this, indices); + return BufferLoad(*this, indices, predicate); } -Stmt Buffer::vstore(Array begin, PrimExpr value) const { +Stmt Buffer::vstore(Array begin, PrimExpr value, Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); DataType value_dtype = value.dtype(); ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.lanes() % n->dtype.lanes() == 0) + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot store " << value_dtype << " to buffer of " << n->dtype; Array indices = begin; - int factor = value_dtype.lanes() / n->dtype.lanes(); - if (factor > 1) { - indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); + PrimExpr base = indices[indices.size() - 1]; + if (value_dtype.is_fixed_length_vector()) { + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1 && base.dtype().is_scalar()) { + indices.Set(indices.size() - 1, Ramp(base, 1, factor)); + } } - return BufferStore(*this, value, indices); + return BufferStore(*this, value, indices, predicate); } String Buffer::scope() const { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index b54be0796372..0fcc8608e9f5 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -772,12 +772,20 @@ void BufferLoadNode::LegalizeDType() { } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicate, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() << "-dimensional indices provided."; + if (predicate.defined()) { + DataType predicate_element_dtype = predicate.value().dtype().element_of(); + ICHECK(predicate_element_dtype.is_bool()) + << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype + << "."; + } + ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); @@ -788,9 +796,8 @@ 
BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicat } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, PrimExpr predicate, Span span) { - return BufferLoad(buffer, indices, predicate, span); - }); + .set_body_typed([](Buffer buffer, Array indices, Optional predicate, + Span span) { return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6bd4d97ce1c6..5df76450ff1e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,8 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, - Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, + Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -477,29 +477,39 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) << "Index dtype and buffer dtype can't both be scalable."; - if (is_index_scalable || is_buffer_dtype_scalable) { - ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; + if (predicate.defined()) { + bool is_predicate_dtype_scalable = predicate.value().dtype().is_scalable_vector(); + ICHECK_EQ(is_value_dtype_scalable, is_predicate_dtype_scalable) + << "Predicate mask dtype and value dtype must both be scalable."; } - int index_lanes; - if (indices.empty()) { - index_lanes = 1; - } else if (is_index_scalable) { - index_lanes = indices.back().dtype().vscale_factor(); - } else { - index_lanes = indices.back().dtype().lanes(); + if (is_index_scalable || is_buffer_dtype_scalable) { + ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; } - int buffer_lanes = - is_buffer_dtype_scalable ? buffer->dtype.vscale_factor() : buffer->dtype.lanes(); - int value_dtype_lanes = - is_value_dtype_scalable ? value.dtype().vscale_factor() : value.dtype().lanes(); + int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor(); + int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor(); + int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes) << "Cannot store value with " << value_dtype_lanes << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; + if (predicate.defined()) { + DataType predicate_dtype = predicate.value().dtype(); + int predicate_dtype_lanes = predicate_dtype.get_lanes_or_vscale_factor(); + ICHECK_EQ(value_dtype_lanes, predicate_dtype_lanes) + << "Got a predicate mask with " << predicate_dtype_lanes + << " lanes, but trying to store a value with " << value_dtype_lanes + << " lanes. 
The number of lanes must match."; + + DataType predicate_element_dtype = predicate_dtype.element_of(); + ICHECK(predicate_element_dtype.is_bool()) + << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype + << "."; + } + runtime::DataType buffer_dtype; if (is_index_scalable || is_buffer_dtype_scalable) { buffer_dtype = buffer->dtype.with_scalable_vscale_factor(buffer_lanes * index_lanes); @@ -524,7 +534,8 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, + Optional predicate, Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 182369ba03fa..d3319004e2f1 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -72,7 +72,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } -bool EnableBufferLevelPredication() { +bool EnableBufferLevelPredication(Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Optional enable_buffer_predication = pass_ctx->GetConfig("tir.enable_buffer_level_predication"); @@ -81,14 +81,14 @@ bool EnableBufferLevelPredication() { } // Use buffer-level predication by default for AArch64 SVE targets - return arith::TargetHasSVE(); + return arith::TargetHasSVE(target); } /*! * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a * predicate expression where possible. * - * \note For now we start with a minimalized case targeting block-level predicates + * \note For now we start with a minimal case targeting block-level predicates * produced by the split schedule primitive, with the potential for predicating * more complex terms in the future if needed. 
* @@ -101,9 +101,9 @@ bool EnableBufferLevelPredication() { * * After: * for i_0 in T.serial(4): - * predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14) - * A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) - * B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate) + * predicate = T.get_active_lane_mask("uint1x4", i_0 * 4, 14) + * A_load = T.meta_var(A.vload([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) + * B.vstore([T.Ramp(i_0 * 4, 1, 4)], A_load, predicate=predicate) */ class TryPredicateBufferAccesses : public StmtExprMutator { public: @@ -171,7 +171,7 @@ class TryPredicateBufferAccesses : public StmtExprMutator { } DataType buf_predicate_dtype = - DataType(DataType::kInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), + DataType(DataType::kUInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), ramp->dtype.is_scalable_vector()); Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_}); @@ -291,7 +291,8 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -682,8 +683,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor success_stmt_pair = TryPredicateBufferAccesses().Run(then_case, condition); bool can_remove_if_then_else = success_stmt_pair.first; @@ -791,6 +792,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor let_binding_; // vectorizable property OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); + /*! \brief The current target context. */ + Target target_; // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. @@ -860,22 +863,41 @@ class Vectorizer : public StmtMutator, public ExprFunctor(tvm::attr::kTarget)) { + target_ = opt_target.value(); + } + } + Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { auto* extent_as_int = op->extent.as(); if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && arith::TargetHasSVE()) - << "Failed to vectorize loop with extent " << op->extent << " for target " - << Target::Current(); + ICHECK(is_scalable_expr && arith::TargetHasSVE(target_)) + << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } ICHECK(is_zero(op->min)); - return Vectorizer(op->loop_var, op->extent)(op->body); + return Vectorizer(op->loop_var, op->extent, target_)(op->body); } else { return StmtMutator::VisitStmt_(op); } } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tvm::attr::kTarget) { + Target previous_target = target_; + target_ = op->node.as().value(); + Stmt new_op = StmtMutator::VisitStmt_(op); + target_ = previous_target; + return new_op; + } + return StmtMutator::VisitStmt_(op); + } + + private: + Target target_ = Target::Current(); }; class VectorizeSkipper : public StmtMutator { @@ -900,7 +922,7 @@ Pass VectorizeLoop(bool enable_vectorize) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); if (enable_vectorize) { - n->body = LoopVectorizer()(std::move(n->body)); + n->body = LoopVectorizer(n->attrs)(std::move(n->body)); } else { n->body = VectorizeSkipper()(std::move(n->body)); } diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py new file mode 100644 index 000000000000..3041ad27e248 --- /dev/null +++ b/tests/python/codegen/test_target_codegen.py @@ -0,0 +1,56 @@ +# Licensed to 
the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm.script import tir as T + + +@tvm.testing.exclude_targets("llvm") +def test_buffer_store_predicate_not_supported(target): + @T.prim_func + def func(b: T.handle): + B = T.match_buffer(b, (8,), "float32") + B.vstore([T.Ramp(0, 2, 4)], T.Broadcast(1.0, 4), predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "Predicated buffer store is not supported." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +@tvm.testing.exclude_targets("llvm") +def test_buffer_load_predicate_not_supported(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (8,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in range(4): + B.vstore( + [T.Ramp(0, 2, 4)], + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)), + ) + + err_msg = "Predicated buffer load is not supported." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 75e31365f432..251e625b8173 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -771,7 +771,7 @@ def test_get_active_lane_mask(): def before(a: T.handle): A = T.match_buffer(a, (30,), "int1") for i in range(T.ceildiv(30, T.vscale() * 4)): - A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30) + A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) with tvm.target.Target(target): out = tvm.build(before) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f1316ae3cee0..f50d63878e4f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1109,5 +1109,34 @@ def func(): built = tvm.build(func, target="llvm") +def test_invalid_volatile_masked_buffer_load(): + @T.prim_func + def func(b: T.handle): + B = T.match_buffer(b, [4]) + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + B[0:4] = A.vload([T.Ramp(0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "The masked load intrinsic does not support declaring load as volatile." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target("llvm"): + tvm.build(func) + + +def test_invalid_volatile_masked_buffer_store(): + @T.prim_func + def func(): + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + A.vstore([T.Ramp(0, 1, 4)], T.Broadcast(0.0, 4), predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "The masked store intrinsic does not support declaring store as volatile." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target("llvm"): + tvm.build(func) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 31a1317e6817..b886a6953330 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -468,6 +468,52 @@ def test_buffer_store_scalable_vec(): assert store.value.dtype == "int32xvscalex4" +def test_buffer_store_predicate_invalid_scalability(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + + err_msg = "Predicate mask dtype and value dtype must both be scalable." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_store_predicate_invalid_lanes(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + + err_msg = ( + "Got a predicate mask with 8 lanes, but trying to store a " + "value with 4 lanes. The number of lanes must match." + ) + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_store_predicate_elements_invalid_type(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + + err_msg = "Predicate mask elements must be boolean values, but got int32." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_load_predicate_elements_invalid_type(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + + err_msg = "Predicate mask elements must be boolean values, but got int32." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + def test_scalable_vec_cast(): b = tvm.tir.decl_buffer((24,), "float32") value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12") diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 44b886eedd6a..f72e6e08a05d 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -547,15 +547,16 @@ def expected(a: T.handle, b: T.handle): T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) for i_0 in range(4): load_a = T.meta_var( - A.load( - [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14) + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), ) ) add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) - B.store( - add_1, + B.vstore( [T.Ramp(i_0 * 4, 1, 4)], - predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), ) mod = tvm.IRModule.from_expr(before) @@ -610,15 +611,15 @@ def expected(a: T.handle, b: T.handle): B = T.match_buffer(b, (16,), "float32") T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) for i_0 in range(4): - A.store( - T.Broadcast(T.float32(2), 4), + A.vstore( [T.Ramp(i_0 * 4, 1, 4)], - predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + T.Broadcast(T.float32(2), 4), + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), ) - B.store( - T.Broadcast(T.float32(1), 4), + B.vstore( [T.Ramp(i_0 * 4, 1, 4)], - predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + T.Broadcast(T.float32(1), 4), + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), ) before_mod = tvm.IRModule.from_expr(before) @@ -665,8 +666,8 @@ def expected(a: T.handle, b: T.handle): def test_vectorize_with_explicitly_disabled_buffer_level_predication(): - # Since the target is has the SVe feature, buffer level predication is enabled - # by default. However, it has been explicitely disabled by the pass context + # Since the target has the SVE feature, buffer level predication is enabled + # by default. However, it has been explicitly disabled by the pass context # option, so no buffer-level predicates should be added. 
@T.prim_func def before(a: T.handle, b: T.handle): @@ -689,10 +690,82 @@ def expected(a: T.handle, b: T.handle): mod = tvm.IRModule.from_expr(before) with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": False}): - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + with tvm.target.Target(sve_target): after = tvm.tir.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) +def test_vectorize_and_predicate_buffer_load_stores_with_sve_func_attr_target(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": sve_target}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "target": sve_target}) + for i_0 in range(4): + load_a = T.meta_var( + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_buffer_load_stores_with_sve_attr_scope_target(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.attr(sve_target, "target", 0): + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + with T.attr(sve_target, "target", 0): + for i_0 in range(4): + load_a = T.meta_var( + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 13e6aec285a6..9e77fa090021 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -956,15 +956,15 @@ def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (256, 256), "float32") T.func_attr({"global_symbol": "func"}) - a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4))) - A.store(a_load, [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 
4))) + A.vstore([0, T.Ramp(0, 2, 4)], a_load, predicate=T.Broadcast(T.bool(False), 4)) expected_output = """ # from tvm.script import tir as T @T.prim_func def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): - A.store(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + A.vstore([0, T.Ramp(0, 2, 4)], A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4)) """ _assert_print(main, expected_output) @@ -977,13 +977,15 @@ def test_predicated_buffer_load_store(): b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), } buffer_load = tir.BufferLoad( - buffer=buffer_map[b], indices=[0, tir.Ramp(0, 4, 4)], predicate=tir.Broadcast(0, 4) + buffer=buffer_map[b], + indices=[0, tir.Ramp(0, 4, 4)], + predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), ) body = tir.BufferStore( buffer=buffer_map[a], value=buffer_load, indices=[0, tir.Ramp(0, 2, 4)], - predicate=tir.Broadcast(0, 4), + predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), ) func = tir.PrimFunc( params=[a, b], @@ -997,7 +999,7 @@ def test_predicated_buffer_load_store(): @T.prim_func(private=True) def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): - A.store(B.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + A.vstore([0, T.Ramp(0, 2, 4)], B.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4)) """ _assert_print(func, expected_output) @@ -1010,18 +1012,35 @@ def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (256, 256), "float32") T.func_attr({"global_symbol": "func"}) - mask = T.meta_var(T.get_active_lane_mask("int1xvscalex4", 0, 13)) - a_load = T.meta_var(A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask)) - A.store(a_load, [0, T.Ramp(0, 2, T.vscale() * 4)], predicate=mask) + mask = T.meta_var(T.get_active_lane_mask("uint1xvscalex4", 0, 13)) + a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask)) + A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], a_load, predicate=mask) expected_output = """ # from tvm.script import tir as T @T.prim_func def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): - A.store(\ -A.load([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)), \ -[0, T.Ramp(0, 2, T.vscale() * 4)], predicate=T.get_active_lane_mask("int1xvscalex4", 0, 13)) + A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13)), predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13)) + """ + _assert_print(main, expected_output) + + +def test_vload_with_explicit_scalable_data_type(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + B[0 : T.vscale() * 4] = A.vload([T.Ramp(0, 1, T.vscale() * 4)], dtype="float32xvscalex4") + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + B[0:T.vscale() * 4] = A[0:T.vscale() * 4] """ _assert_print(main, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 
bcc318caf6f2..ee404f08efb8 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3358,8 +3358,10 @@ def func(a: T.handle, b: T.handle): A = T.match_buffer(a, (4,), "float32") B = T.match_buffer(b, (8,), "float32") for i_0 in range(4): - load_a = T.meta_var(A.load([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(1.0, 4))) - B.store(load_a, [T.Ramp(0, 2, 4)], predicate=T.Broadcast(1.0, 4)) + load_a = T.meta_var( + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) + ) + B.vstore([T.Ramp(0, 2, 4)], load_a, predicate=T.Broadcast(T.bool(True), 4)) return func From b5a26470e68c83c79829effa0eaf26ca32d4235a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 14 May 2024 13:34:58 +0000 Subject: [PATCH 4/7] Fix lint Change-Id: Idd3f3593fe524f3444487c520d947dfd53386db0 --- src/tir/transforms/vectorize_loop.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index d3319004e2f1..aa62d5850513 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -863,7 +863,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(tvm::attr::kTarget)) { target_ = opt_target.value(); } From eb1667e4d03c612b2751b3655c7268c0a3e62d24 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 14 May 2024 15:18:25 +0000 Subject: [PATCH 5/7] Fix some failing tests * vload/vstore updates that were missed previously * int1 -> bool updates * fix gpu target tests Fixes a test and updates comments referencing old load/store api Change-Id: I26a0c480d2dedee442ca0116909a7751d1dfa9ac --- src/script/printer/tir/buffer.cc | 4 +- tests/python/codegen/test_target_codegen.py | 40 ++++++++++++++++++- .../test_tir_transform_vectorize.py | 6 +-- .../test_tvmscript_ir_builder_tir.py | 2 +- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index c2d504c2fd7e..87db53061ceb 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -275,7 +275,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); ExprDoc value = d->AsDoc(store->value, p->Attr("value")); - // Use .store(...) syntax when there is a predicate + // Use .vstore(...) syntax when there is a predicate if (store->predicate.defined()) { ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); @@ -293,7 +293,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); - // Use .load(...) syntax when there is a predicate + // Use .vload(...) 
syntax when there is a predicate if (load->predicate.defined()) { ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py index 3041ad27e248..bae15b5377e3 100644 --- a/tests/python/codegen/test_target_codegen.py +++ b/tests/python/codegen/test_target_codegen.py @@ -21,7 +21,7 @@ from tvm.script import tir as T -@tvm.testing.exclude_targets("llvm") +@tvm.testing.parametrize_targets("c") def test_buffer_store_predicate_not_supported(target): @T.prim_func def func(b: T.handle): @@ -34,7 +34,25 @@ def func(b: T.handle): tvm.build(func) -@tvm.testing.exclude_targets("llvm") +@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0") +def test_buffer_store_predicate_not_supported_gpu(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (2, 3), "float32") + B = T.match_buffer(b, (6,), "float32") + T.func_attr({"global_symbol": "main"}) + for i_0 in T.thread_binding(3, thread="threadIdx.x"): + B.vstore( + [T.Ramp(i_0, 1, 4)], T.Broadcast(1.0, 4), predicate=T.Broadcast(T.bool(True), 4) + ) + + err_msg = "Predicated buffer store is not supported." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +@tvm.testing.parametrize_targets("c") def test_buffer_load_predicate_not_supported(target): @T.prim_func def func(a: T.handle, b: T.handle): @@ -52,5 +70,23 @@ def func(a: T.handle, b: T.handle): tvm.build(func) +@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0") +def test_buffer_load_predicate_not_supported_gpu(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (8,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in T.thread_binding(3, thread="threadIdx.x"): + B.vstore( + [T.Ramp(0, 2, 4)], + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)), + ) + + err_msg = "Predicated buffer load is not supported." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index f72e6e08a05d..e02c227b05b7 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -186,10 +186,10 @@ def main(a: T.handle, n: T.int32, x: T.int32): T.float32(1), extent ) else: - A.store( - T.Broadcast(T.float32(2), T.vscale() * 4), + A.vstore( [T.Ramp(0, 1, T.vscale() * 4)], - predicate=T.get_active_lane_mask("int1xvscalex4", 0, n), + T.Broadcast(T.float32(2), T.vscale() * 4), + predicate=T.get_active_lane_mask("uint1xvscalex4", 0, n), ) with tvm.target.Target(target): diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 4636646b9216..daad7f53140b 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -472,7 +472,7 @@ def test_ir_builder_tir_buffer_store_predicate(): buffer_a = T.Buffer((30,), "float32") value = T.broadcast(0.11, T.vscale() * 4) index = T.ramp(0, 1, T.vscale() * 4) - predicate = T.broadcast(1, T.vscale() * 4) + predicate = T.broadcast(T.bool(True), T.vscale() * 4) with IRBuilder() as ib: T.buffer_store(buffer_a, value, [index], predicate) From 941cd38b305464d38e1329971623344eb5dd296c Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Sun, 19 May 2024 11:19:37 +0000 Subject: [PATCH 6/7] Address comments - Correct doc strings - Correct typo in error message - Add some additional checks for BufferLoad Change-Id: Ie25563d569c0ed729ac915a6ba3a724a9e191014 --- include/tvm/tir/buffer.h | 2 +- python/tvm/tir/buffer.py | 3 +-- python/tvm/tir/expr.py | 3 +-- src/target/llvm/codegen_llvm.h | 4 ++-- src/tir/ir/expr.cc | 16 +++++++++++++- src/tir/transforms/inject_rolling_buffer.cc | 4 ++-- tests/python/tir-base/test_tir_nodes.py | 23 +++++++++++++++++++++ 7 files changed, 45 insertions(+), 10 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 8719476af98f..276198abb89c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -210,7 +210,7 @@ class Buffer : public ObjectRef { * \param begin The beginning index * \param dtype The data type to be loaded. * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be - * stored. The number lanes of the mask must be equal to the number of lanes in value. + * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ TVM_DLL PrimExpr vload(Array begin, DataType dtype, Optional predicate = NullOpt) const; diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 61a11e0330a6..501d13b17e3d 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -115,8 +115,7 @@ def vload(self, begin, dtype=None, predicate=None): predicate : Optional[PrimExpr] A vector mask of boolean values indicating which lanes of a vector are to be - stored. The number lanes of the mask must be equal to the number of lanes in - value. + loaded. The number lanes of the mask must be equal to the number of lanes being loaded. 
Returns ------- diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index baed599add5d..c78bb9e7ecd0 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1100,8 +1100,7 @@ class BufferLoad(PrimExprWithOp): predicate : Optional[PrimExpr] A vector mask of boolean values indicating which lanes of a vector are to be - stored. The number lanes of the mask must be equal to the number of lanes in - value. + loaded. The number lanes of the mask must be equal to the number of lanes being loaded. """ buffer: Buffer diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 9b28e892ef9b..302a0d97b3f4 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -331,8 +331,8 @@ class CodeGenLLVM : public ExprFunctor, * \param indices The indices at which the buffer is being accessed. * * \param predicate A vector mask of boolean values indicating which lanes of a - * vector are to be stored. The number lanes of the mask must be equal to the - * number of lanes in value. + * vector are to be accessed. The number lanes of the mask must be equal to the + * number of lanes being accessed. * * \param value_dtype The datatype to be read from (BufferLoad) or * written to (BufferStore) the buffer. diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 0fcc8608e9f5..034d585573c9 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -780,7 +780,21 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optionalpredicate.defined()) - << "Predicated buffer store is not current supported in the inject rolling buffer pass."; + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "the inject rolling buffer pass."; Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index b886a6953330..eeedae1f127c 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -514,6 +514,29 @@ def test_buffer_load_predicate_elements_invalid_type(): tvm.tir.BufferLoad(b, [index], predicate) +def test_buffer_store_predicate_invalid_scalability(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + + err_msg = "Predicate mask dtype and load indices must both be scalable." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + +def test_buffer_store_predicate_invalid_lanes(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + + err_msg = ( + "Got a predicate mask with 8 lanes, but trying to load a " + "vector with 4 lanes. The number of lanes must match." 
+ ) + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + def test_scalable_vec_cast(): b = tvm.tir.decl_buffer((24,), "float32") value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12") From cbd2e48e5fe2dd1a4971cbd60cc7562ab1b3714c Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 22 May 2024 08:59:11 +0000 Subject: [PATCH 7/7] Account for buffer lanes in predicate lane check Change-Id: I821210665e36c26bfa37fc9ed380b5d03c9e816e --- src/tir/ir/expr.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 034d585573c9..1506082003fd 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -787,9 +787,10 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optionaldtype.get_lanes_or_vscale_factor(); int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor(); int predicate_lanes = predicate_dtype.get_lanes_or_vscale_factor(); - ICHECK_EQ(index_lanes, predicate_lanes) + ICHECK_EQ(index_lanes * buffer_lanes, predicate_lanes) << "Got a predicate mask with " << predicate_lanes << " lanes, but trying to load a vector with " << index_lanes << " lanes. The number of lanes must match.";
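
For reference, a minimal construction sketch of the rule this check enforces (illustrative only; the buffer shape and lane counts are arbitrary). With a scalar-element buffer, `buffer_lanes` is 1, so the mask needs the same `4 * vscale` lanes as the ramp index:

```
# Illustrative sketch (not part of this patch series): building a predicated
# BufferLoad whose mask satisfies index_lanes * buffer_lanes == predicate_lanes.
import tvm

buf = tvm.tir.decl_buffer((24,), "int32")               # scalar elements: buffer_lanes == 1
index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())   # 4 * vscale index lanes
predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4 * tvm.tir.vscale())
load = tvm.tir.BufferLoad(buf, [index], predicate)      # lane counts match, so this is accepted
```

If the buffer itself had vectorized elements (e.g. "int32x4"), the mask would instead need `index_lanes * 4` lanes, which is the case this final patch accounts for.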