Skip to content

Commit

Permalink
[SVE] Add support for representing and creating buffer-level predicat…
Browse files Browse the repository at this point in the history
…es (#16966)

* [SVE] Add support for representing and creating buffer-level predicates

Representation
--------------
This commit extends `BufferLoad` and `BufferStore` to accept a predicate
mask argument indicating which lanes in a vectorized buffer load/store
should be read/written.

As a simple example, we can load all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
```

Or disable loading all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
```

In TVMScript, buffer loads and stores are currently displayed using a
"short-hand" notation e.g. `A[0:4]`, but there was no clear path for
extending this notation to support predicates. Therefore, a "long-hand"
notation is introduced e.g. `A.load([T.Ramp(0, 1, 4)], predicate=...)`.
The TVMScript printer falls back to the long-hand notation whenever
predicates are specified.

Creation
--------
Buffer-level predication becomes more motivating when combined with the
`tir.get_active_lane_mask` intrinsic. It can be used to mask off lanes
when the vectorized axis is not divisible by the vector length. A
detailed example and rationale can be found in the
[RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).

Predicated buffer load/stores are created in the `VectorizeLoop` pass
via `TryPredicateBufferAccesses`. This pass aims to convert block-level
predicates e.g.
```
for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
```
to buffer-level predicates, e.g.
```
for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)
```
It takes a conservative approach for now, focussing only on expressions
produced by the split scheduling primitive, but more complex expressions
could be supported in the future.

`TryPredicateBufferAccesses` can be explicitly enabled/disabled with the
`tir.enable_buffer_level_predication` pass context option. By default it
will be disabled, unless the target supports SVE, in which case it will
be enabled by default.

Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>

Change-Id: Idde259a7d7e4536f00ed3a1dafedd0a5d24a1593

* Fix lint and correct test config option name

Change-Id: I864475c3d03e9b426ce5ef987989216d57f3e019

* Address review comments

This includes:
* Taking into account possibility of target being overridden in
  the vectorize pass.
* Predicate PrimExpr -> Optional<PrimExpr>
* Checking that predicate is not used for any target that doesn't
  support it.
* Use vload/vstore API as opposed to load/store
* int1 mask -> uint1 mask for boolean representation. This is converted
  to int1 in the LLVM backend.

Change-Id: I4da0705352e321f6be6333a5bb777caa6a6ca9ef

* Fix lint

Change-Id: Idd3f3593fe524f3444487c520d947dfd53386db0

* Fix some failing tests

* vload/vstore updates that were missed previously
* int1 -> bool updates
* fix gpu target tests

Fixes a test and updates comments referencing old load/store api

Change-Id: I26a0c480d2dedee442ca0116909a7751d1dfa9ac

* Address comments

- Correct doc strings
- Correct typo in error message
- Add some additional checks for BufferLoad

Change-Id: Ie25563d569c0ed729ac915a6ba3a724a9e191014

* Account for buffer lanes in predicate lane check

Change-Id: I821210665e36c26bfa37fc9ed380b5d03c9e816e
  • Loading branch information
lhutton1 authored May 28, 2024
1 parent b598f28 commit 20d8c53
Show file tree
Hide file tree
Showing 45 changed files with 1,196 additions and 108 deletions.
5 changes: 4 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,11 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
* \param buffer The buffer.
* \param value The value to be stored.
* \param indices The indices location to be stored.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* stored. The number lanes of the mask must be equal to the number of lanes in value.
*/
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate);

/*!
* \brief The prefetch hint for a buffer
Expand Down
10 changes: 8 additions & 2 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,20 @@ class Buffer : public ObjectRef {
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* loaded. The number lanes of the mask must be equal to the number of lanes in being loaded.
*/
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype,
Optional<PrimExpr> predicate = NullOpt) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
* \param predicate A vector mask of boolean values indicating which lanes of a vector are to be
* stored. The number lanes of the mask must be equal to the number of lanes in value.
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value,
Optional<PrimExpr> predicate = NullOpt) const;

/*!
* \brief Get a flattened version of the buffer
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode {
Buffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for loading values. */
Optional<PrimExpr> predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferLoad";
Expand Down Expand Up @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode {
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
Optional<PrimExpr> predicate = NullOpt, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode {
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for storing values. */
Optional<PrimExpr> predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferStore";
Expand All @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());
Optional<PrimExpr> predicate = NullOpt, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def _updater(data):
return _updater


def create_updater_16_to_17():
"""
Create an update to upgrade json from v0.16 to v0.17
Returns
-------
fupdater : function
The updater function
"""

def _update_predicate_argument(item, nodes):
null_value_idx = 0
null_value = nodes[null_value_idx]
assert str(null_value) == "{'type_key': ''}", f"Expected a null value but got {null_value}"
item["attrs"]["predicate"] = str(null_value_idx)
return item

node_map = {
"tir.BufferLoad": _update_predicate_argument,
"tir.BufferStore": _update_predicate_argument,
}

return create_updater(node_map, "0.16", "0.17")


def create_updater_15_to_16():
"""
Create an update to upgrade json from v0.15 to v0.16
Expand Down Expand Up @@ -316,5 +341,7 @@ def _from_version(data):
data = create_updater({}, "0.14", "0.15")(data)
if _from_version(data).startswith("0.15"):
data = create_updater_15_to_16()(data)
if _from_version(data).startswith("0.16"):
data = create_updater_16_to_17()(data)

return json.dumps(data, indent=2)
8 changes: 7 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,7 @@ def buffer_store(
buffer: Buffer, # pylint: disable=redefined-outer-name
value: PrimExpr,
indices: List[Union[PrimExpr, slice]],
predicate: Optional[PrimExpr] = None,
) -> None:
"""Buffer store node.
Expand All @@ -1278,6 +1279,11 @@ def buffer_store(
indices : List[Union[PrimExpr, slice]]
The indices location to be stored.
predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number lanes of the mask must be equal to the number of lanes in
value.
"""
from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel

Expand All @@ -1298,7 +1304,7 @@ def buffer_store(
if isinstance(value, bool) and buffer.dtype == "bool":
value = IntImm("bool", value)
return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
buffer, value, expr_indices
buffer, value, expr_indices, predicate
)


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
elif isinstance(res, str):
# Ignore docstrings
pass
elif isinstance(res, tvm.tir.stmt.BufferStore):
T.buffer_store(res.buffer, res.value, res.indices, res.predicate)
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")

Expand Down
17 changes: 13 additions & 4 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0,
self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
)

def vload(self, begin, dtype=None):
def vload(self, begin, dtype=None, predicate=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
Expand All @@ -113,16 +113,20 @@ def vload(self, begin, dtype=None):
The data type to be loaded,
can be vector type which have lanes that is multiple of Buffer.dtype
predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
loaded. The number lanes of the mask must be equal to the number of lanes being loaded.
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore
return _ffi_api.BufferVLoad(self, begin, dtype, predicate) # type: ignore

def vstore(self, begin, value):
def vstore(self, begin, value, predicate=None):
"""Generate a Stmt that store value into begin index.
Parameters
Expand All @@ -133,13 +137,18 @@ def vstore(self, begin, value):
value : Expr
The value to be stored.
predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number lanes of the mask must be equal to the number of lanes in
value.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value) # type: ignore
return _ffi_api.BufferVStore(self, begin, value, predicate) # type: ignore

def scope(self):
"""Return the storage scope associated with this buffer.
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,20 +1093,28 @@ class BufferLoad(PrimExprWithOp):
The buffer to be loaded.
indices : List[PrimExpr]
The buffer indices.
The buffer indices to load values from.
span : Optional[Span]
The location of this expression in the source code.
predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
loaded. The number lanes of the mask must be equal to the number of lanes being loaded.
"""

buffer: Buffer
indices: List[PrimExpr]

def __init__(
self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None
self,
buffer: Buffer,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices, span # type: ignore
_ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore
)


Expand Down
9 changes: 8 additions & 1 deletion python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,31 @@ class BufferStore(Stmt):
indices : List[PrimExpr]
The indices location to be stored.
predicate : Optional[PrimExpr]
A vector mask of boolean values indicating which lanes of a vector are to be
stored. The number lanes of the mask must be equal to the number of lanes in
value.
span : Optional[Span]
The location of the stmt in the source code.
"""

buffer: Buffer
value: PrimExpr
indices: List[PrimExpr]
predicate: Optional[PrimExpr]
span: Optional[Span]

def __init__(
self,
buffer: Buffer,
value: PrimExpr,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferStore, buffer, value, indices, span # type: ignore
_ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore
)


Expand Down
5 changes: 3 additions & 2 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// "T.vscale" and the compile target uses a scalable architecture extension like
// SVE, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
Target curr_target = Target::Current();
if (ContainsVscaleCall(simplified)) {
if (TargetHasSVE()) {
if (TargetHasSVE(curr_target)) {
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by substituting "
"with known values of vscale was not performed. This proof currently only supports "
"AArch64 SVE targets, but the target was "
<< Target::Current();
<< curr_target;
}
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) {
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
unsigned int max_val =
*std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
return MakeBound(1, max_val);
Expand Down
3 changes: 1 addition & 2 deletions src/arith/scalable_expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
return can_prove_expr;
}

bool TargetHasSVE() {
Target current_target = Target::Current();
bool TargetHasSVE(Target current_target) {
bool has_sve{false};
if (current_target.defined()) {
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
Expand Down
4 changes: 3 additions & 1 deletion src/arith/scalable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/target/target.h>

#include <optional>
#include <vector>
Expand Down Expand Up @@ -79,9 +80,10 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr

/*!
* \brief Check whether the compilation target supports SVE
* \param target The target to check.
* \return Whether SVE is supported
*/
bool TargetHasSVE();
bool TargetHasSVE(Target target);

} // namespace arith
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
Expand Down
5 changes: 3 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) {
return var;
}

void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate = NullOpt) {
runtime::DataType buffer_dtype = buffer->dtype;
bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector();
Expand Down Expand Up @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
}
value = tvm::cast(lhs_dtype, value);
}
AddToParent(tvm::tir::BufferStore(buffer, value, indices));
AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate));
}

void Prefetch(Buffer buffer, Array<Range> bounds) {
Expand Down
Loading

0 comments on commit 20d8c53

Please sign in to comment.