Skip to content

Commit

Permalink
[LinearLayouts] Faster pext algorithm (#5621)
Browse files Browse the repository at this point in the history
We also skip the LinearLayout test for HIP as it's currently failing.

Regarding the use of `getWarpSize` and `getNumWarpsPerCTA`, which are
not correct for LinearLayouts with broadcasting as noted in
#5617, we found almost all the
uses are in AMD land. Changing these into calling the functions that act
on the module is tricky, as the module is not currently accessible at
the caller site in most of them. As such, we leave this refactor up to
AMD folks.
  • Loading branch information
lezcano authored Jan 15, 2025
1 parent e7e9b3d commit 9895a1f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
33 changes: 22 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
#include <intrin.h>

static int __builtin_clz(unsigned x) {
unsigned long r;
_BitScanReverse(&r, x);
return static_cast<int>(r);
}

static int __builtin_ctz(unsigned x) {
unsigned long r;
_BitScanForward(&r, x);
return static_cast<int>(r);
}

#endif

namespace mlir {
Expand Down Expand Up @@ -601,18 +608,22 @@ Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) {
if (mask == 0xFFFFFFFF)
return a;

// We implement a blocked algorithm to avoid generating too many instructions
// Implements the blocked algorithm from
// https://forums.developer.nvidia.com/t/pdep-and-pext-functionality-for-cuda/270973
uint32_t mskConst = mask;
uint32_t extcnt = 0;
Value result = i32_val(0);
int resultPos = 0;
while (mask) {
int start = __builtin_ctz(mask);
int width = __builtin_ctz(~(mask >> start));
Value shifted = lshr(a, i32_val(start));
Value widthMask = i32_val(((1u << width) - 1));
Value blockVal = and_(shifted, widthMask);
result = or_(result, shl(blockVal, i32_val(resultPos)));
resultPos += width;
mask &= ~(((1u << width) - 1) << start);
while (mskConst) {
uint32_t oldmsk = mskConst;
uint32_t bitgrplsb = mskConst & (-mskConst);
mskConst &= bitgrplsb + mskConst;
uint32_t bitgrp = mskConst ^ oldmsk;
uint32_t lsbpos = 31 - __builtin_clz(bitgrplsb);
// like popcount for a number 0..01..1..0 but portable
uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos));
uint32_t shift = lsbpos - extcnt;
extcnt += grplen;
result = or_(result, lshr(and_(i32_val(bitgrp), a), i32_val(shift)));
}
return result;
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ class DecomposeScaledBlocked
mmaEnc.getInstrShape()[versionMajor == 3
? 0
: mmaEnc.getInstrShape().size() - 2];
auto warpSize = getWarpSize(newAEncoding);
auto mod = scaledDotOp->getParentOfType<ModuleOp>();
int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
assert(instrShapeM <= warpSize);
// Necessary choice to leave all the scales of the tile in that given warp
auto threadsPerWarp =
Expand Down
2 changes: 2 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2793,6 +2793,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
pytest.skip("Skipping test because it runs out of shared memory")
if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
pytest.skip("Skipping sum reduction on float16 due to accuracy issues")
if is_hip() and isinstance(src_layout, LinearLayout):
pytest.skip("FIXME: LinearLayout not supported on HIP")

if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
src_layout[2] = 16 if dtype_str == "float16" else 8
Expand Down

0 comments on commit 9895a1f

Please sign in to comment.