Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SVE] Add support for representing and creating buffer-level predicates #16966

Merged
merged 7 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,11 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
* \param buffer The buffer.
* \param value The value to be stored.
* \param indices The indices location to be stored.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* stored. The number of lanes of the mask must be equal to the number of lanes in value.
*/
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate);

/*!
* \brief The prefetch hint for a buffer
Expand Down
10 changes: 8 additions & 2 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,20 @@ class Buffer : public ObjectRef {
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.
*/
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype,
Optional<PrimExpr> predicate = NullOpt) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* stored. The number of lanes of the mask must be equal to the number of lanes in value.
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value,
Optional<PrimExpr> predicate = NullOpt) const;

/*!
* \brief Get a flattened version of the buffer
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode {
Buffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for loading values. */
Optional<PrimExpr> predicate;

/*!
 * \brief Visit every field of this node with the given attribute visitor.
 * \param v The attribute visitor; each field is passed together with its string key.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
// The optional predicate mask is visited under the "predicate" key.
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferLoad";
Expand Down Expand Up @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode {
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
Optional<PrimExpr> predicate = NullOpt, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode {
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for storing values. */
Optional<PrimExpr> predicate;

/*!
 * \brief Visit every field of this node with the given attribute visitor.
 * \param v The attribute visitor; each field is passed together with its string key.
 */
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
// The optional predicate mask is visited under the "predicate" key.
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferStore";
Expand All @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());
Optional<PrimExpr> predicate = NullOpt, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def _updater(data):
return _updater


def create_updater_16_to_17():
    """
    Create an updater to upgrade json from v0.16 to v0.17

    Version 0.17 adds a "predicate" attribute to tir.BufferLoad and
    tir.BufferStore nodes. Nodes coming from v0.16 get that attribute
    set to the index of the null node.

    Returns
    -------
    fupdater : function
        The updater function
    """

    def _update_predicate_argument(item, nodes):
        """Point the new "predicate" attribute of ``item`` at the null node."""
        # The node at index 0 is expected to be the null value; verify this
        # before referencing it.
        null_value_idx = 0
        null_value = nodes[null_value_idx]
        # Compare the node itself rather than its repr string so the check
        # does not depend on how dicts are formatted.
        assert null_value == {"type_key": ""}, f"Expected a null value but got {null_value}"
        # Attributes hold node references as stringified indices.
        item["attrs"]["predicate"] = str(null_value_idx)
        return item

    node_map = {
        "tir.BufferLoad": _update_predicate_argument,
        "tir.BufferStore": _update_predicate_argument,
    }

    return create_updater(node_map, "0.16", "0.17")


def create_updater_15_to_16():
"""
Create an update to upgrade json from v0.15 to v0.16
Expand Down Expand Up @@ -316,5 +341,7 @@ def _from_version(data):
data = create_updater({}, "0.14", "0.15")(data)
if _from_version(data).startswith("0.15"):
data = create_updater_15_to_16()(data)
if _from_version(data).startswith("0.16"):
data = create_updater_16_to_17()(data)

return json.dumps(data, indent=2)
8 changes: 7 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,7 @@ def buffer_store(
buffer: Buffer, # pylint: disable=redefined-outer-name
value: PrimExpr,
indices: List[Union[PrimExpr, slice]],
predicate: Optional[PrimExpr] = None,
) -> None:
"""Buffer store node.

Expand All @@ -1278,6 +1279,11 @@ def buffer_store(

indices : List[Union[PrimExpr, slice]]
The indices location to be stored.

predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number of lanes of the mask must be equal to the number of lanes in
value.
"""
from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel

Expand All @@ -1298,7 +1304,7 @@ def buffer_store(
if isinstance(value, bool) and buffer.dtype == "bool":
value = IntImm("bool", value)
return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
buffer, value, expr_indices
buffer, value, expr_indices, predicate
)


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
elif isinstance(res, str):
# Ignore docstrings
pass
elif isinstance(res, tvm.tir.stmt.BufferStore):
T.buffer_store(res.buffer, res.value, res.indices, res.predicate)
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")

Expand Down
17 changes: 13 additions & 4 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0,
self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
)

def vload(self, begin, dtype=None):
def vload(self, begin, dtype=None, predicate=None):
"""Generate an Expr that loads dtype from begin index.

Parameters
Expand All @@ -113,16 +113,20 @@ def vload(self, begin, dtype=None):
The data type to be loaded;
can be a vector type whose number of lanes is a multiple of that of Buffer.dtype

predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.

Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore
return _ffi_api.BufferVLoad(self, begin, dtype, predicate) # type: ignore

def vstore(self, begin, value):
def vstore(self, begin, value, predicate=None):
"""Generate a Stmt that stores value at the begin index.

Parameters
Expand All @@ -133,13 +137,18 @@ def vstore(self, begin, value):
value : Expr
The value to be stored.

predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number of lanes of the mask must be equal to the number of lanes in
value.

Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value) # type: ignore
return _ffi_api.BufferVStore(self, begin, value, predicate) # type: ignore

def scope(self):
"""Return the storage scope associated with this buffer.
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,20 +1093,28 @@ class BufferLoad(PrimExprWithOp):
The buffer to be loaded.

indices : List[PrimExpr]
The buffer indices.
The buffer indices to load values from.

span : Optional[Span]
The location of this expression in the source code.

predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
loaded. The number of lanes of the mask must be equal to the number of lanes being loaded.
"""

buffer: Buffer
indices: List[PrimExpr]

def __init__(
self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None
self,
buffer: Buffer,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices, span # type: ignore
_ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore
)


Expand Down
9 changes: 8 additions & 1 deletion python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,31 @@ class BufferStore(Stmt):
indices : List[PrimExpr]
The indices location to be stored.

predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number of lanes of the mask must be equal to the number of lanes in
value.

span : Optional[Span]
The location of the stmt in the source code.
"""

buffer: Buffer
value: PrimExpr
indices: List[PrimExpr]
predicate: Optional[PrimExpr]
span: Optional[Span]

def __init__(
self,
buffer: Buffer,
value: PrimExpr,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferStore, buffer, value, indices, span # type: ignore
_ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore
)


Expand Down
5 changes: 3 additions & 2 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// "T.vscale" and the compile target uses a scalable architecture extension like
// SVE, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
Target curr_target = Target::Current();
if (ContainsVscaleCall(simplified)) {
if (TargetHasSVE()) {
if (TargetHasSVE(curr_target)) {
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by substituting "
"with known values of vscale was not performed. This proof currently only supports "
"AArch64 SVE targets, but the target was "
<< Target::Current();
<< curr_target;
}
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) {
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
unsigned int max_val =
*std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
return MakeBound(1, max_val);
Expand Down
3 changes: 1 addition & 2 deletions src/arith/scalable_expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
return can_prove_expr;
}

bool TargetHasSVE() {
Target current_target = Target::Current();
bool TargetHasSVE(Target current_target) {
bool has_sve{false};
if (current_target.defined()) {
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
Expand Down
4 changes: 3 additions & 1 deletion src/arith/scalable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/target/target.h>

#include <optional>
#include <vector>
Expand Down Expand Up @@ -79,9 +80,10 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr

/*!
* \brief Check whether the compilation target supports SVE
* \param target The target to check.
* \return Whether SVE is supported
*/
bool TargetHasSVE();
bool TargetHasSVE(Target target);

} // namespace arith
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
Expand Down
5 changes: 3 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) {
return var;
}

void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate = NullOpt) {
runtime::DataType buffer_dtype = buffer->dtype;
bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector();
Expand Down Expand Up @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
}
value = tvm::cast(lhs_dtype, value);
}
AddToParent(tvm::tir::BufferStore(buffer, value, indices));
AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate));
}

void Prefetch(Buffer buffer, Array<Range> bounds) {
Expand Down
Loading
Loading