-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for 3-5D TMA to allow loading non-matmul operands #5207
Conversation
third_party/nvidia/backend/driver.c
Outdated
// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA | ||
// implies hasLeadingOffset = true in SMEM encoding, which is currently not | ||
// supported for higher rank TMA copies. This convention needs to be in sync | ||
// with the TMA lowering pass in codegen. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ThomasRaoux This is my takeaway from our discussion yesterday. Let me know if this is ok.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is correct but this convention makes me bit nervous as we won't be able to handle the case where we 3D inputs for a batch matmul kind of cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With or without this PR, the underlying limitation that prevents 3D TMA with swizzling for matmul inputs continues to exist. So I would say this changes just make the limitation explicit in the API temporarily until the issue is fixed, at which point we can remove this convention.
For batched matmul kind of workloads, the flattening of a higher dim tensor into 2D is straightforward. So there is still an escape hatch.
Overall, I think this PR won't make the situation any worse. Maybe the new TMA representation you mentioned would solve all of those issues. But while we wait for that, it would be good to enable more use cases for TMA within Triton - it is an "experimental" feature, after all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well the case that I think is interesting is to write a batch matmul case where the global tensor is 3D but each block loads a 2D tensor and compute matmul on it.
So I would say this changes just make the limitation explicit in the API temporarily until the issue is fixed, at which point we can remove this convention.
even then you wouldn't want swizzling for this case right?
For batched matmul kind of workloads, the flattening of a higher dim tensor into 2D is straightforward. So there is still an escape hatch.
That breaks the bound checking part right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
even then you wouldn't want swizzling for this case right?
Oh sorry if there was misunderstanding. I believe 2D-5D TMA should be treated equally, and I do hope that we can remove this restriction. Whether or not swizzing would be beneficial for my use case is a separate question I need to investigate in the future. Right now my inner-most axis size is 16B so no swizzling would be applied anyway. But I can tweak the sizes of the inner-most two dims, to make the inner-most axis wider and apply swizzling if I want to.
That breaks the bound checking part right?
hmm I haven't thought about that but indeed I don't see how OOB check can work if some dims are flattened (if possible at all).
Maybe the device-side tensor-map creation can be used? After we get the batch id (or group id for grouped gemm), we can use 2D TMA. I'm not sure if that's supported now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to get clarified on your concern. I thought this PR would not have any negative implications, since higher-rank TMA with swizzling doesn't work anyway. But are you saying that, one important special case of 3D TMA, where the actual load is 2D (since one of copy dim sizes is always one) is supposed to work with the current impl, but my change would disallow the swizzling for that case as well?
If that's the case, the only workaround, without a proper fix, would be to make swizzling a parameter for higher-rank TMA that the user provide. By default, we don't swizzle. We also need to pass the same swizzling param to tl.experimental_descriptor_load(...)
to make the codegen and runtime code in sync.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another idea would be to base the decision to enable swizzling not on the rank of the global tensor but the "effective rank" of the box, where by "effective rank" I mean a rank after removing size-1 dims.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code changes look reasonable to me but it is definitely showing the limitation of the convention we have around swizzling.
@peterbell10 is working on the next representation for TMA and this representation will become deprecated or kept for experimental.
If this is going to block you I think it is fine to go with this as this is in the experimental path but if there is an alternative we should look at how to fit that in the new path.
third_party/nvidia/backend/driver.c
Outdated
// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA | ||
// implies hasLeadingOffset = true in SMEM encoding, which is currently not | ||
// supported for higher rank TMA copies. This convention needs to be in sync | ||
// with the TMA lowering pass in codegen. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is correct but this convention makes me bit nervous as we won't be able to handle the case where we 3D inputs for a batch matmul kind of cases
Hi @ThomasRaoux, I want to come back to the discussion on this PR. To address your concern, I made a change according to my comment #5207 (comment) and added a test for dot3d where we use the 3D TMA API but the actual load is 2D. I'm now allowing swizzling to be applied in such case. |
q = tl._experimental_descriptor_load(q_desc, [batch_id, startm, 0], [1, BLOCK_M, BLOCK_K], tl.float16) | ||
k = tl._experimental_descriptor_load(k_desc, [batch_id, 0, startn], [1, BLOCK_K, BLOCK_N], tl.float16) | ||
qk = tl.dot(q.reshape(BLOCK_M, BLOCK_K), k.reshape(BLOCK_K, BLOCK_N), out_dtype=tl.float32) | ||
o_ptrs = o_ptr + batch_id * stride_ob + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This work but this would definitely not generate efficient code right? My point was that this prevents us from writing an efficient batch matmul
We discussed and agreed offline that we are not going to add more changes to the experimental API. |
Mostly a mechanical change to expand the experimental support for TMA, which is currently limited to 1-2D. I have a use case for TMA to load 3D or 4D tensor which encodes blocked scales from MXFP in a specialized layout.
Swizzling is disabled for higher rank TMA, to set
hasLeadingOffset = false
for the dst SMEM allocated in TMA lowering. The new unittest fails if swizzling is enabled for TMA andhasLeadingOffset = true
. I believe this is simply due to implementation limitations, so I hope we can enable swizziling for higher rank TMA in the future.cc @ThomasRaoux @mbrookhart @csullivan
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD
.Select one of the following.
/test
forlit
tests/unittest
for C++ tests/python/test
for end-to-end testsFILL THIS IN
.Select one of the following.
lit
tests.lit
tests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)