-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SVE] Add support for representing and creating buffer-level predicates #16966
[SVE] Add support for representing and creating buffer-level predicates #16966
Conversation
784a75a
to
efb057b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the functionality overall. I have a couple of requested changes, which mostly fall under the categories below.
- TIR edge cases, when
Target::Current()
may be overridden. - Using
Optional<PrimExpr>
instead ofPrimExpr
. - Validating that
!predicate.defined()
in any target that does not support it.
src/target/llvm/codegen_llvm.cc
Outdated
} else { | ||
load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), | ||
is_volatile); | ||
} | ||
#elif TVM_LLVM_VERSION >= 80 | ||
auto load = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR only adds CreateMaskedLoad
when TVM_LLVM_VERSION >= 110
. If somebody is using an older version of LLVM, it would silently ignore the predicate for the load/store. We should either support it, or throw an exception.
It looks like CreateMaskedLoad
has been supported in LLVM since this commit, so I'd lean toward adding it in the other #elif
branches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, thanks. I've added support for previous versions of LLVM. I've checked the build with the following versions of LLVM: 7*, 8*, 9*, 10, 11, 12, 13, 17
* fails to build due to other seemingly unrelated errors
@@ -700,5 +700,31 @@ def before(a: T.handle): | |||
assert "get.active.lane.mask" in ll | |||
|
|||
|
|||
@pytest.mark.skipif( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This validates that we have the correct output for architectures that support SVE, but it doesn't test the behavior of other targets that do not (yet) support predicated loads/stores. While the VectorizeLoop
pass would only insert a predicated load/store for targets that support it, the predicated load/store could still be generated in hand-written kernels, or through other transforms in the future.
Can we add a test, parametrized over each target tested in CI, which attempts to compile a PrimFunc containing predicated loads/stores? For each target that supports sve, tvm.build
should compile without error, and for each target that does not, tvm.build
should raise an exception.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done - I wasn't able to check all targets locally, so I'm hoping they pass CI here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking the time to review @Lunderberg, I'm working through the comments but wanted to ask a couple of questions / respond to some of the comments before I continue
src/tir/transforms/vectorize_loop.cc
Outdated
@@ -72,6 +72,126 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { | |||
return Broadcast(e, CreateNewLanes(is_scalable, lanes)); | |||
} | |||
|
|||
bool EnableBufferLevelPredication() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that makes sense. Do you know if there is any general infrastructure for keep track of the current target (which takes into account this override functionality) from within a pass? Otherwise I feel we will be duplicating this functionality in multiple places. I was thinking something similar to: LexicalOnDeviceMixin
(assuming I understood it correctly)
No problem, and thank you on the revisions! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @lhutton1 for all the work on this (it's a lot of work) and @Lunderberg for constructive reviews! I've only got some minor nits.
834ba44
to
d8795a0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @lhutton1 LGTM! 🚀
@tvm-bot rerun |
Representation -------------- This commit extends `BufferLoad` and `BufferStore` to accept a predicate mask argument indicating which lanes in a vectorized buffer load/store should be read/written. As a simple example, we can load all lanes: ``` tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8)) ``` Or disable loading all lanes: ``` tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8)) ``` In TVMScript, buffer loads and stores are currently displayed using a "short-hand" notation e.g. `A[0:4]`, but there was no clear path for extending this notation to support predicates. Therefore, a "long-hand" notation is introduced e.g. `A.load([T.Ramp(0, 1, 4)], predicate=...)`. The TVMScript printer falls back to the long-hand notation whenever predicates are specified. Creation -------- Buffer-level predication becomes more motivating when combined with the `tir.get_active_lane_mask` intrinsic. It can be used to mask off lanes when the vectorized axis is not divisible by the vector length. A detailed example and rationale can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication). Predicated buffer load/stores are created in the `VectorizeLoop` pass via `TryPredicateBufferAccesses`. This pass aims to convert block-level predicates e.g. ``` for i_0 in T.serial(4): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 ``` to buffer-level predicates, e.g. ``` for i_0 in T.serial(4): predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14) A_load = T.meta_var(A.load([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) B.store(A_load, [T.Ramp(i_0 * 4, 1, 4)], predicate=predicate) ``` It takes a conservative approach for now, focussing only on expressions produced by the split scheduling primitive, but more complex expressions could be supported in the future. `TryPredicateBufferAccesses` can be explicitly enabled/disabled with the `tir.enable_buffer_level_predication` pass context option. By default it will be disabled, unless the target supports SVE, in which case it will be enabled by default. Co-authored-by: Elen Kalda <[email protected]> Co-authored-by: Neil Hickey <[email protected]> Change-Id: Idde259a7d7e4536f00ed3a1dafedd0a5d24a1593
Change-Id: I864475c3d03e9b426ce5ef987989216d57f3e019
This includes: * Taking into account possibility of target being overridden in the vectorize pass. * Predicate PrimExpr -> Optional<PrimExpr> * Checking that predicate is not used for any target that doesn't support it. * Use vload/vstore API as opposed to load/store * int1 mask -> uint1 mask for boolean representation. This is converted to int1 in the LLVM backend. Change-Id: I4da0705352e321f6be6333a5bb777caa6a6ca9ef
* vload/vstore updates that were missed previously * int1 -> bool updates * fix gpu target tests Fixes a test and updates comments referencing old load/store api Change-Id: I26a0c480d2dedee442ca0116909a7751d1dfa9ac
- Correct doc strings - Correct typo in error message - Add some additional checks for BufferLoad Change-Id: Ie25563d569c0ed729ac915a6ba3a724a9e191014
Change-Id: I821210665e36c26bfa37fc9ed380b5d03c9e816e
d8795a0
to
cbd2e48
Compare
friendly ping @Lunderberg if you have some free time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This patch is waiting here for more than a week now. I think we should merge it before it starts creating merge-conflicts, so I'm going to merge it now.
@Lunderberg, I'm sure @lhutton1 will be happy to take any other outstanding comment in follow-up patches, please reply when you can. Thanks.
Replying taking a while to come back. Happy to take comments in follow-up patches.
if i_0 * 4 * T.vscale() + i_1 < 14: | ||
B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 | ||
|
||
with tvm.target.Target(target): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This style moving foward, instead, attach the target attribute to the PrimFunc itself
Apologies for not getting back to the review, and thank you for making the changes! |
Representation
This commit extends
BufferLoad
andBufferStore
to accept a predicate mask argument indicating which lanes in a vectorized buffer load/store should be read/written.As a simple example, we can load all lanes:
Or disable loading all lanes:
In TVMScript, buffer loads and stores are currently displayed using a "short-hand" notation e.g.
A[0:4]
, but there was no clear path for extending this notation to support predicates. Therefore, the vload/vstore notation is used e.g.A.vload([T.Ramp(0, 1, 4)], predicate=...)
. The TVMScript printer falls back to the vload/vstore notation whenever predicates are specified.Creation
Buffer-level predication becomes more motivating when combined with the
tir.get_active_lane_mask
intrinsic. It can be used to mask off lanes when the vectorized axis is not divisible by the vector length. A detailed example and rationale can be found in the RFC.Predicated buffer load/stores are created in the
VectorizeLoop
pass viaTryPredicateBufferAccesses
. This pass aims to convert block-level predicates e.g.to buffer-level predicates, e.g.
It takes a conservative approach for now, focussing only on expressions produced by the split scheduling primitive, but more complex expressions could be supported in the future.
TryPredicateBufferAccesses
can be explicitly enabled/disabled with thetir.enable_buffer_level_predication
pass context option. By default it will be disabled, unless the target supports SVE, in which case it will be enabled by default.Note: this commit depends on #16965, so also contains the contents of #16965.Co-authored-by: Elen Kalda [email protected]
Co-authored-by: Neil Hickey [email protected]