Skip to content

Commit

Permalink
[SVE] Add support for representing and creating buffer-level predicates
Browse files Browse the repository at this point in the history
Representation
--------------
This commit extends `BufferLoad` and `BufferStore` to accept a predicate
mask argument indicating which lanes in a vectorized buffer load/store
should be read/written.

As a simple example, we can load all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
```

Or disable loading all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
```

In TVMScript, buffer loads and stores are currently displayed using a
"short-hand" notation e.g. `A[0:4]`, but there was no clear path for
extending this notation to support predicates. Therefore, a "long-hand"
notation is introduced e.g. `A.load([T.Ramp(0, 1, 4)], predicate=...)`.
The TVMScript printer falls back to the long-hand notation whenever
predicates are specified.

Creation
--------
Buffer-level predication becomes more motivating when combined with the
`tir.get_active_lane_mask` intrinsic. It can be used to mask off lanes
when the vectorized axis is not divisible by the vector length. A
detailed example and rationale can be found in the
[RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).

Predicated buffer load/stores are created in the `VectorizeLoop` pass
via `TryPredicateBufferAccesses`. This pass aims to convert block-level
predicates e.g.
```
for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
```
to buffer-level predicates, e.g.
```
for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)
```
It takes a conservative approach for now, focussing only on expressions
produced by the split scheduling primitive, but more complex expressions
could be supported in the future.

`TryPredicateBufferAccesses` can be explicitly enabled/disabled with the
`tir.enable_buffer_level_predication` pass context option. By default it
will be disabled, unless the target supports SVE, in which case it will
be enabled by default.

Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>

Change-Id: Idde259a7d7e4536f00ed3a1dafedd0a5d24a1593
  • Loading branch information
lhutton1 committed May 2, 2024
1 parent 61c44f9 commit 784a75a
Show file tree
Hide file tree
Showing 34 changed files with 810 additions and 60 deletions.
4 changes: 3 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,10 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
* \param buffer The buffer.
* \param value The value to be stored.
* \param indices The indices location to be stored.
* \param predicate A vector mask of int1 values indicating which lanes of a vector are to be
* stored.
*/
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, PrimExpr predicate);

/*!
* \brief The prefetch hint for a buffer
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode {
Buffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for loading values. */
PrimExpr predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferLoad";
Expand Down Expand Up @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode {
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices,
PrimExpr predicate = PrimExpr(), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode {
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
/*! \brief The predicate mask for storing values. */
PrimExpr predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

Expand All @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.BufferStore";
Expand All @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());
PrimExpr predicate = PrimExpr(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def _updater(data):
return _updater


def create_updater_16_to_17():
"""
Create an update to upgrade json from v0.16 to v0.17
Returns
-------
fupdater : function
The updater function
"""

def _update_predicate_argument(item, nodes):
null_value_idx = 0
null_value = nodes[null_value_idx]
assert str(null_value) == "{'type_key': ''}", f"Expected a null value but got {null_value}"
item["attrs"]["predicate"] = str(null_value_idx)
return item

node_map = {
"tir.BufferLoad": _update_predicate_argument,
"tir.BufferStore": _update_predicate_argument,
}

return create_updater(node_map, "0.16", "0.17")


def create_updater_15_to_16():
"""
Create an update to upgrade json from v0.15 to v0.16
Expand Down Expand Up @@ -316,5 +341,7 @@ def _from_version(data):
data = create_updater({}, "0.14", "0.15")(data)
if _from_version(data).startswith("0.15"):
data = create_updater_15_to_16()(data)
if _from_version(data).startswith("0.16"):
data = create_updater_16_to_17()(data)

return json.dumps(data, indent=2)
6 changes: 5 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,7 @@ def buffer_store(
buffer: Buffer, # pylint: disable=redefined-outer-name
value: PrimExpr,
indices: List[Union[PrimExpr, slice]],
predicate: Optional[PrimExpr] = None,
) -> None:
"""Buffer store node.
Expand All @@ -1278,6 +1279,9 @@ def buffer_store(
indices : List[Union[PrimExpr, slice]]
The indices location to be stored.
predicate : Optional[PrimExpr]
A vector mask of int1 values indicating which lanes of a vector are to be stored.
"""
from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel

Expand All @@ -1298,7 +1302,7 @@ def buffer_store(
if isinstance(value, bool) and buffer.dtype == "bool":
value = IntImm("bool", value)
return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
buffer, value, expr_indices
buffer, value, expr_indices, predicate
)


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
elif isinstance(res, str):
# Ignore docstrings
pass
elif isinstance(res, tvm.tir.stmt.BufferStore):
T.buffer_store(res.buffer, res.value, res.indices, res.predicate)
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")

Expand Down
51 changes: 51 additions & 0 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,57 @@ def vstore(self, begin, value):
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value) # type: ignore

def load(self, indices, predicate=None):
"""
Load values at specified indices from buffer.
Longhand notation that can be used for complex buffer load
expressions. For example, when the load involves predication.
Parameters
----------
indices : List[PrimExpr]
The buffer indices to load values from.
predicate : Optional[PrimExpr]
A vector mask of int1 values indicating which lanes of a vector are to be loaded.
Returns
-------
BufferLoad
A buffer load Expr.
"""
from .expr import BufferLoad # pylint: disable=import-outside-toplevel

return BufferLoad(self, indices, predicate)

def store(self, value, indices, predicate=None):
"""
Store given value at the specified indices in the buffer.
Longhand notation that can be used for complex buffer store
statements. For example, when the store involves predication.
Parameters
----------
value : PrimExpr
The value to be stored.
indices : List[PrimExpr]
The buffer indices to store values to.
predicate : Optional[PrimExpr]
A vector mask of int1 values indicating which lanes of a vector are to be stored.
Returns
-------
BufferStore
A buffer store Stmt.
"""
from .stmt import BufferStore # pylint: disable=import-outside-toplevel

return BufferStore(self, value, indices, predicate)

def scope(self):
"""Return the storage scope associated with this buffer.
Returns
Expand Down
13 changes: 10 additions & 3 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,20 +1093,27 @@ class BufferLoad(PrimExprWithOp):
The buffer to be loaded.
indices : List[PrimExpr]
The buffer indices.
The buffer indices to load values from.
span : Optional[Span]
The location of this expression in the source code.
predicate : Optional[PrimExpr]
A vector mask of int1 values indicating which lanes of a vector are to be loaded.
"""

buffer: Buffer
indices: List[PrimExpr]

def __init__(
self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None
self,
buffer: Buffer,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices, span # type: ignore
_ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore
)


Expand Down
7 changes: 6 additions & 1 deletion python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,29 @@ class BufferStore(Stmt):
indices : List[PrimExpr]
The indices location to be stored.
predicate : Optional[PrimExpr]
A vector mask of int1 values indicating which lanes of a vector are to be stored.
span : Optional[Span]
The location of the stmt in the source code.
"""

buffer: Buffer
value: PrimExpr
indices: List[PrimExpr]
predicate: Optional[PrimExpr]
span: Optional[Span]

def __init__(
self,
buffer: Buffer,
value: PrimExpr,
indices: List[PrimExpr],
predicate: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferStore, buffer, value, indices, span # type: ignore
_ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore
)


Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
Expand Down
5 changes: 3 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) {
return var;
}

void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
PrimExpr predicate = PrimExpr()) {
runtime::DataType buffer_dtype = buffer->dtype;
bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector();
Expand Down Expand Up @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
}
value = tvm::cast(lhs_dtype, value);
}
AddToParent(tvm::tir::BufferStore(buffer, value, indices));
AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate));
}

void Prefetch(Buffer buffer, Array<Range> bounds) {
Expand Down
23 changes: 21 additions & 2 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferStore>( //
"", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc buffer = d->AsDoc<ExprDoc>(store->buffer, p->Attr("buffer"));
return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)],
/*rhs=*/d->AsDoc<ExprDoc>(store->value, p->Attr("value")), NullOpt);
ExprDoc value = d->AsDoc<ExprDoc>(store->value, p->Attr("value"));

// Use .store(...) syntax when there is a predicate
if (store->predicate.defined()) {
ExprDoc indices = d->AsDoc<ExprDoc>(store->indices, p->Attr("indices"));
ExprDoc predicate = d->AsDoc<ExprDoc>(store->predicate, p->Attr("predicate"));
return ExprStmtDoc(
buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate}));
}

return AssignDoc(
/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)],
/*rhs=*/value, NullOpt);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferLoad>( //
"", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc buffer = d->AsDoc<ExprDoc>(load->buffer, p->Attr("buffer"));

// Use .load(...) syntax when there is a predicate
if (load->predicate.defined()) {
ExprDoc indices = d->AsDoc<ExprDoc>(load->indices, p->Attr("indices"));
ExprDoc predicate = d->AsDoc<ExprDoc>(load->predicate, p->Attr("predicate"));
return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate});
}

return buffer[BufferIndices(load->indices, p->Attr("indices"), d)];
});

Expand Down
Loading

0 comments on commit 784a75a

Please sign in to comment.