[BACKEND] Use LL to simplify redundant elements check and fix related issues #5225

Merged
merged 12 commits on Nov 22, 2024
21 changes: 5 additions & 16 deletions python/test/unit/language/test_core.py
@@ -5436,21 +5436,11 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
pytest.skip("Skip testing MMAv3 on devices with CC < 9")

num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
# TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA
warps_per_cta = src_layout.warps_per_cta
interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1],
[1, 1], [0, 1])

def do_test(src_layout, dst_layout):
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
#interm = {interm}
"""

conversion = f"""
%12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
"""

ir = layouts + f"""
@@ -5460,6 +5450,7 @@ def do_test(src_layout, dst_layout):
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
@@ -5468,12 +5459,10 @@ def do_test(src_layout, dst_layout):
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #src>
%3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #interm>
""" + conversion + f"""
%15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm>
%16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm>
%17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>, tensor<{M}x{N}xi32, #interm>
tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>
%12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>
tt.return
}}
}}
151 changes: 64 additions & 87 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -8,6 +8,7 @@
#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
@@ -24,87 +25,57 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
namespace {

// Return the mask for the unique data accessed by given tensor type.
// Used to mask out the redundant data accessed by threads.
Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc, const NVIDIA::TargetInfo &targetInfo) {
// NOTE: Redundant memory load is allowed in triton, but redundant memory store
// is not allowed.
// mask = true: thread can write
// mask = false: thread should not write
Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy,
ConversionPatternRewriter &rewriter, Location loc,
int regIdx, const NVIDIA::TargetInfo &targetInfo) {
auto ctx = moduleOp.getContext();
auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
Value mask = int_val(1, 1);
auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
auto tid = tid_val();
auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
auto mask = int_val(1, 1);
auto kReg = str_attr("register");
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kBlock = str_attr("block");
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto threadOrder = triton::gpu::getThreadOrder(layout);
SmallVector<unsigned> warpOrder(rank);
if (auto enc = dyn_cast<DotOperandEncodingAttr>(layout)) {
warpOrder =
triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1);
auto layout = tensorTy.getEncoding();
auto ll = triton::gpu::toLinearLayout(shape, layout);
assert(ll.has_value() && "Failed to convert layout to linear layout");
auto freeVariableMasks = ll->getFreeVariableMasks();
auto regMasks = freeVariableMasks[kReg];
if (regMasks & regIdx) {
// Step 1: check register redundancy
mask = int_val(1, 0);
} else {
warpOrder = triton::gpu::getWarpOrder(layout);
}
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
// TODO: [DOT LL]
// The delinearize function is not entirely correct for certain layouts,
// such as wgmma. The correct approach is to convert a legacy layout to its
// corresponding linear layout and use the linear layout's
// getFreeVariableMasks to identify redundant elements.
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
// Do not write duplicated data when multicast is enabled
if (triton::gpu::getNumCTAs(layout) > 1) {
auto _0 = i32_val(0);
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
auto CTAOrder = triton::gpu::getCTAOrder(layout);

auto multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

for (unsigned dim = 0; dim < rank; ++dim) {
// Skip when multicast is not enabled in this dimension
if (CTAsPerCGA[dim] == CTASplitNum[dim])
continue;
// This wrapping rule must be consistent with emitCTAOffsetForLayout
unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum));
// Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
// CTA0 and CTA2 holds data of block0,
// CTA1 and CTA3 holds data of block1.
// Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
// be masked. We add the following mask:
// multiDimClusterCTAId[dim] / splitNum == 0
// Actually in all existing cases of multicast, splitNum is always 1.
// The mask is equivalent to:
// multiDimClusterCTAId[dim] == 0
mask = and_(mask, icmp_eq(repId, _0));
Value warpSize =
i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
// Step 2: check lane and warp redundancy
auto laneMasks = freeVariableMasks[kLane];
auto warpMasks = freeVariableMasks[kWarp];
mask = and_(mask, icmp_eq(and_(i32_val(laneMasks), laneId), i32_val(0)));
mask = and_(mask, icmp_eq(and_(i32_val(warpMasks), warpId), i32_val(0)));
if (numCTAs > 1) {
// Step 3: check block redundancy
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
auto ctaMasks = freeVariableMasks[kBlock];
mask = and_(mask, icmp_eq(and_(i32_val(ctaMasks), ctaId), i32_val(0)));
}
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 of
// CTA0 can write
mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0)));
mask = and_(mask, icmp_eq(tid, i32_val(0)));
if (numCTAs > 1) {
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
// If the tensor is not ranked, then it is a scalar and only thread 0 of
// CTA0 within the cluster can write
mask = and_(mask, icmp_eq(ctaId, i32_val(0)));
}
}
return mask;
}
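
Note on the new approach above: a free-variable mask returned by the linear layout's getFreeVariableMasks() marks the bits of a hardware index (register slot, lane, warp, block) that do not change which tensor element is addressed, so a thread may store only when all of those bits are zero. That is exactly what the and_/icmp_eq chain emits. Below is a minimal standalone C++ sketch of that bit test with made-up mask values; it is not Triton code.

// Standalone sketch of the free-variable-mask test (hypothetical mask values;
// not Triton code). A set bit in a free-variable mask means flipping that bit
// of the hardware index (register slot, lane, warp, or block) addresses the
// same tensor element, so only the instance with all such bits cleared writes.
#include <cstdint>
#include <cstdio>

bool isUniqueWriter(uint32_t regIdx, uint32_t laneId, uint32_t warpId,
                    uint32_t ctaId, uint32_t regMask, uint32_t laneMask,
                    uint32_t warpMask, uint32_t ctaMask) {
  // Step 1: register redundancy -- a duplicated register slot never writes.
  if (regMask & regIdx)
    return false;
  // Steps 2 and 3: lane, warp, and block redundancy -- write only when every
  // free bit of the corresponding hardware index is zero.
  return (laneMask & laneId) == 0 && (warpMask & warpId) == 0 &&
         (ctaMask & ctaId) == 0;
}

int main() {
  // Hypothetical layout that broadcasts each element across lane bit 2
  // (laneMask = 0b100): lanes 0-3 write, lanes 4-7 are masked off.
  for (uint32_t lane = 0; lane < 8; ++lane)
    std::printf("lane %u writes: %d\n", lane,
                isUniqueWriter(/*regIdx=*/0, lane, /*warpId=*/0, /*ctaId=*/0,
                               /*regMask=*/0, /*laneMask=*/0b100,
                               /*warpMask=*/0, /*ctaMask=*/0));
  return 0;
}

This also explains why the call sites below now pass a register index (vecStart or i) to getRedundantDataMask: register-level redundancy depends on which register slot is being written, not only on the thread.
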
@@ -437,7 +408,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
<< mask << "\n";
}

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
auto moduleOp = op->getParentOfType<ModuleOp>();
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;
@@ -485,6 +456,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

Value mask = getRedundantDataMask(moduleOp, valueTy, rewriter, loc,
vecStart, targetInfo);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;

auto *asmAddr =
@@ -577,7 +550,6 @@ struct AtomicCASOpConversion
<< " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread << "\n";

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);

@@ -607,6 +579,8 @@
os << op.getSem();
auto scope = stringifyMemSyncScope(op.getScope()).str();
atom.global().o(semStr).o(scope).o("cas").o(sTy);
Value mask =
getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);

if (tensorTy) {
@@ -736,12 +710,12 @@ struct AtomicRMWOpConversion
<< " packed = " << packed << " origin vec = " << vecOrig
<< " numElems = " << numElems;

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);

auto packedTy = vec_ty(valueElemTy, packed);
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
Value rmwPtr = ptrElements[i];
Value mask =
getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
@@ -976,6 +950,7 @@ struct AsyncCopyGlobalToLocalOpConversion
<< vecBytes << " bytes";
}

auto moduleOp = op->getParentOfType<ModuleOp>();
for (int i = 0; i < shmemAddrs.size(); i++) {
// It's possible that vecTy is larger than 128 bits, in which case we have
// to use multiple cp.async instructions.
@@ -1003,24 +978,26 @@
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
// When 'other != 0' is supported, we will need to fold the
// op.getMask() and redundantDataMask() into the same predicate, the
// way it is done for LoadOp.
auto selectOp =
select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}

// When 'other != 0' is supported, we will need to fold the op.getMask()
// and redundantDataMask() into the same predicate, the way it is done
// for LoadOp.
Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo);

// TODO: Masking does not work for CTA multicast with cp.async. This is
// a quick and dirty workaround to avoid the issue.
bool skipMaskForMultiCTA = triton::gpu::getNumCTAs(srcLayout) > 1;
if (!skipMaskForMultiCTA) {
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
.predicate(maskVal);
} else {
if (skipMaskForMultiCTA) {
// TODO: Masking does not work for CTA multicast with cp.async.
// XXX(@peterbell10): In the multi-CTA mode, the redundant data might
// be on different CTAs which don't share the same smem address space,
// so we might need to load the same data multiple times.
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
} else {
Value mask = getRedundantDataMask(moduleOp, srcTy, rewriter, loc,
elemIdx, targetInfo);
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
.predicate(mask);
}
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
}
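
A minimal standalone sketch of the predication choice above (the helper names are hypothetical, not the Triton codegen API): with a single CTA the redundant-data mask can safely guard cp.async, but under CTA multicast the logically redundant copy may target another CTA's shared memory, so each CTA issues the copy unpredicated.

// Sketch of the cp.async predication decision (hypothetical helpers, not the
// Triton API). With CTA multicast the "redundant" copy may fill a different
// CTA's shared memory, so it must not be masked off; with a single CTA the
// redundant-data mask can guard the copy as usual.
#include <cstdio>

void issueCpAsync(int copyIdx) {
  std::printf("copy %d: unpredicated\n", copyIdx);
}
void issueCpAsyncPredicated(int copyIdx, bool uniqueWriter) {
  std::printf("copy %d: predicated, executes=%d\n", copyIdx, uniqueWriter);
}

void emitCopy(int copyIdx, int numCTAs, bool uniqueWriter) {
  if (numCTAs > 1) {
    // Multi-CTA: every CTA needs its own copy into local shared memory.
    issueCpAsync(copyIdx);
  } else {
    // Single CTA: skip copies whose data another thread already brings in.
    issueCpAsyncPredicated(copyIdx, uniqueWriter);
  }
}

int main() {
  emitCopy(/*copyIdx=*/0, /*numCTAs=*/1, /*uniqueWriter=*/false);
  emitCopy(/*copyIdx=*/1, /*numCTAs=*/2, /*uniqueWriter=*/false);
  return 0;
}
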