Commit e558838

[BACKEND] Use LL to simplify redundant elements check and fix related issues (#5225)
Jokeren authored Nov 22, 2024
1 parent 4330372 commit e558838
Showing 2 changed files with 70 additions and 104 deletions.
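For context on the approach: in a linear layout, a "free" bit of a hardware ID (lane, warp, or block) is one that maps to the same tensor element when flipped, so among all IDs that differ only in free bits, only the representative with those bits cleared needs to store. The snippet below is a minimal standalone sketch of that predicate; isUniqueWriter and the mask values are hypothetical illustrations, not part of the Triton API.

#include <cstdio>

// Minimal sketch of the free-variable-mask idea behind the new
// getRedundantDataMask: a hardware-ID bit that is "free" does not change
// which tensor element the thread holds, so only the thread whose free
// bits are all zero acts as the unique writer. All values are hypothetical.
bool isUniqueWriter(unsigned laneId, unsigned warpId, unsigned ctaId,
                    unsigned laneFreeMask, unsigned warpFreeMask,
                    unsigned ctaFreeMask) {
  return (laneId & laneFreeMask) == 0 && (warpId & warpFreeMask) == 0 &&
         (ctaId & ctaFreeMask) == 0;
}

int main() {
  // Example: lane bits 3 and 4 are free (0x18), i.e. four lanes hold each
  // element; only lanes 0..7 have both bits clear and therefore store.
  for (unsigned lane = 0; lane < 32; ++lane)
    if (isUniqueWriter(lane, /*warpId=*/0, /*ctaId=*/0,
                       /*laneFreeMask=*/0x18, /*warpFreeMask=*/0,
                       /*ctaFreeMask=*/0))
      std::printf("lane %u writes\n", lane);
  return 0;
}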
21 changes: 5 additions & 16 deletions python/test/unit/language/test_core.py
@@ -5436,21 +5436,11 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
pytest.skip("Skip testing MMAv3 on devices with CC < 9")

num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
# TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA
warps_per_cta = src_layout.warps_per_cta
interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1],
[1, 1], [0, 1])

def do_test(src_layout, dst_layout):
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
#interm = {interm}
"""

conversion = f"""
%12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
"""

ir = layouts + f"""
@@ -5460,6 +5450,7 @@ def do_test(src_layout, dst_layout):
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
@@ -5468,12 +5459,10 @@ def do_test(src_layout, dst_layout):
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #src>
%3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #interm>
""" + conversion + f"""
%15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm>
%16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm>
%17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>, tensor<{M}x{N}xi32, #interm>
tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>
%12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>
tt.return
}}
}}
153 changes: 65 additions & 88 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -8,6 +8,7 @@
#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
@@ -24,87 +25,57 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
namespace {

// Return the mask for the unique data accessed by given tensor type.
// Used to mask out the redundant data accessed by threads.
Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc, const NVIDIA::TargetInfo &targetInfo) {
// NOTE: Redundant memory load is allowed in triton, but redundant memory store
// is not allowed.
// mask = true: thread can write
// mask = false: thread should not write
Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy,
ConversionPatternRewriter &rewriter, Location loc,
int regIdx, const NVIDIA::TargetInfo &targetInfo) {
auto ctx = moduleOp.getContext();
auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
Value mask = int_val(1, 1);
auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
auto tid = tid_val();
auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
auto mask = true_val();
auto kReg = str_attr("register");
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kBlock = str_attr("block");
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto threadOrder = triton::gpu::getThreadOrder(layout);
SmallVector<unsigned> warpOrder(rank);
if (auto enc = dyn_cast<DotOperandEncodingAttr>(layout)) {
warpOrder =
triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1);
auto layout = tensorTy.getEncoding();
auto ll = triton::gpu::toLinearLayout(shape, layout);
assert(ll.has_value() && "Failed to convert layout to linear layout");
auto freeVariableMasks = ll->getFreeVariableMasks();
auto regMasks = freeVariableMasks[kReg];
if (regMasks & regIdx) {
// Step 1: check register redundancy
mask = false_val();
} else {
warpOrder = triton::gpu::getWarpOrder(layout);
}
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
// TODO: [DOT LL]
// The delinearize function is not entirely correct for certain layouts,
// such as wgmma. The correct approach is to convert a legacy layout to its
// corresponding linear layout and use the linear layout's
// getFreeVariableMasks to identify redundant elements.
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
// Do not write duplicated data when multicast is enabled
if (triton::gpu::getNumCTAs(layout) > 1) {
auto _0 = i32_val(0);
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
auto CTAOrder = triton::gpu::getCTAOrder(layout);

auto multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

for (unsigned dim = 0; dim < rank; ++dim) {
// Skip when multicast is not enabled in this dimension
if (CTAsPerCGA[dim] == CTASplitNum[dim])
continue;
// This wrapping rule must be consistent with emitCTAOffsetForLayout
unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum));
// Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
// CTA0 and CTA2 holds data of block0,
// CTA1 and CTA3 holds data of block1.
// Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
// be masked. We add the following mask:
// multiDimClusterCTAId[dim] / splitNum == 0
// Actually in all existing cases of multicast, splitNum is always 1.
// The mask is equivalent to:
// multiDimClusterCTAId[dim] == 0
mask = and_(mask, icmp_eq(repId, _0));
Value warpSize =
i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
// Step 2: check lane and warp redundancy
auto laneMasks = freeVariableMasks[kLane];
auto warpMasks = freeVariableMasks[kWarp];
mask = and_(mask, icmp_eq(and_(i32_val(laneMasks), laneId), i32_val(0)));
mask = and_(mask, icmp_eq(and_(i32_val(warpMasks), warpId), i32_val(0)));
if (numCTAs > 1) {
// Step 3: check block redundancy
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
auto ctaMasks = freeVariableMasks[kBlock];
mask = and_(mask, icmp_eq(and_(i32_val(ctaMasks), ctaId), i32_val(0)));
}
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 of
// CTA0 can write
mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0)));
mask = and_(mask, icmp_eq(tid, i32_val(0)));
if (numCTAs > 1) {
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
// If the tensor is not ranked, then it is a scalar and only thread 0 of
// CTA0 within the cluster can write
mask = and_(mask, icmp_eq(ctaId, i32_val(0)));
}
}
return mask;
}
@@ -264,7 +235,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

PTXBuilder ptxBuilder;

Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
Value pred = mask ? maskElems[vecStart] : true_val();

const std::string readConstraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
@@ -437,7 +408,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
<< mask << "\n";
}

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
auto moduleOp = op->getParentOfType<ModuleOp>();
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;
@@ -485,6 +456,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

Value mask = getRedundantDataMask(moduleOp, valueTy, rewriter, loc,
vecStart, targetInfo);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;

auto *asmAddr =
@@ -577,7 +550,6 @@ struct AtomicCASOpConversion
<< " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread << "\n";

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);

@@ -607,6 +579,8 @@
os << op.getSem();
auto scope = stringifyMemSyncScope(op.getScope()).str();
atom.global().o(semStr).o(scope).o("cas").o(sTy);
Value mask =
getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);

if (tensorTy) {
@@ -736,12 +710,12 @@ struct AtomicRMWOpConversion
<< " packed = " << packed << " origin vec = " << vecOrig
<< " numElems = " << numElems;

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);

auto packedTy = vec_ty(valueElemTy, packed);
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
Value rmwPtr = ptrElements[i];
Value mask =
getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
@@ -976,6 +950,7 @@ struct AsyncCopyGlobalToLocalOpConversion
<< vecBytes << " bytes";
}

auto moduleOp = op->getParentOfType<ModuleOp>();
for (int i = 0; i < shmemAddrs.size(); i++) {
// It's possible that vecTy is larger than 128 bits, in which case we have
// to use multiple cp.async instructions.
@@ -1003,24 +978,26 @@
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
// When 'other != 0' is supported, we will need to fold the
// op.getMask() and redundantDataMask() into the same predicate, the
// way it is done for LoadOp.
auto selectOp =
select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}

// When 'other != 0' is supported, we will need to fold the op.getMask()
// and redundantDataMask() into the same predicate, the way it is done
// for LoadOp.
Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo);

// TODO: Masking does not work for CTA multicast with cp.async. This is
// a quick and dirty workaround to avoid the issue.
bool skipMaskForMultiCTA = triton::gpu::getNumCTAs(srcLayout) > 1;
if (!skipMaskForMultiCTA) {
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
.predicate(maskVal);
} else {
if (skipMaskForMultiCTA) {
// TODO: Masking does not work for CTA multicast with cp.async.
// XXX(@peterbell10): In the multi-CTA mode, the redundant data might
// be on different CTAs which don't share the same smem address space,
// so we might need to load the same data multiple times.
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
} else {
Value mask = getRedundantDataMask(moduleOp, srcTy, rewriter, loc,
elemIdx, targetInfo);
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
.predicate(mask);
}
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
}
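One detail worth noting in the new helper's signature: getRedundantDataMask also takes the register (per-thread element) index, because a free register bit means the element at that index duplicates data already held at a lower index within the same thread, so its mask is simply false. The fragment below is a hedged standalone sketch of that check; elementIsRedundant and the mask value are hypothetical, not Triton code.

#include <cstdio>

// Sketch of the per-element register check the new getRedundantDataMask
// performs before the lane/warp/block checks: an element index with any
// free register bit set duplicates data already stored by another element
// of the same thread. The mask value is hypothetical.
bool elementIsRedundant(unsigned regIdx, unsigned regFreeMask) {
  return (regFreeMask & regIdx) != 0;
}

int main() {
  // Hypothetical layout with 8 elements per thread where register bit 2 is
  // free: elements 4..7 mirror elements 0..3 and must not be stored again.
  const unsigned regFreeMask = 0x4;
  for (unsigned i = 0; i < 8; ++i)
    std::printf("element %u: %s\n", i,
                elementIsRedundant(i, regFreeMask) ? "skip store" : "store");
  return 0;
}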
