Skip to content

Commit

Permalink
[LAYOUTS] Use LLs for Hopper whenever we wouldn't use ldmatrix (#5235)
Browse files Browse the repository at this point in the history
The legacy path has some bugs for cases like `kWidth=1`. I'm starting to
port Hopper to use LLs to try to isolate them.
  • Loading branch information
lezcano authored Nov 26, 2024
1 parent deee78f commit e2dc77b
Show file tree
Hide file tree
Showing 15 changed files with 164 additions and 124 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ python/triton/language/extra
# Proton
python/triton/profiler

# Pytest
pytest.ini

# Instrumentation
python/triton/instrumentation

Expand Down
7 changes: 7 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton {

class TargetInfoBase {
public:
virtual bool supportMaximumMinimum() const = 0;
Expand Down Expand Up @@ -37,6 +38,12 @@ class TargetInfoBase {
pred);
}

virtual bool canUseStMatrix(RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) const = 0;

virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
Value ptr, Value val) const = 0;

Expand Down
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
// bit width of the tensor in the future to support more flexible tensor
// encodings
std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);
LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == rank-1) ? k : m;
int mmaStride = (order[0] == rank-1) ? m : k;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

Expand All @@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == rank-1) ? n : k;
int mmaStride = (order[0] == rank-1) ? k : n;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
(elementTypeSize == 16 || elementTypeSize == 8) &&
dotOperandLayout.getKWidth() == 32 / elementTypeSize;
return ans;
}

Expand Down
61 changes: 26 additions & 35 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,28 +376,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// completed before we can remove the layoutIsOK check:
// 1. Support for AMD's WMMA
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
return !useLegacyMMAConversion;
}
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto parent = dotOperand.getParent();
if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
return false;
}
if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (nvidiaMma.isAmpere()) {
return true;
}
}
if (isa<AMDMfmaEncodingAttr>(parent)) {
return true;
}
return false;
layout = dotOperand.getParent();
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;

if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
return !useLegacyMMAConversion;
}
if (isa<LinearEncodingAttr>(layout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
Expand All @@ -408,6 +394,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
return failure();
}
// FIXME [Dot LL] Remove this once we implement this trick in LLs
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

Expand Down Expand Up @@ -498,34 +488,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// don't need to avoid duplicate writes.
// Input dims: [reg, lane, warp]
// Output dims: [offset, iteration]
std::optional<LinearLayout> shmemStoreLayout =
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
bool isStMatrix = shmemStoreLayout.has_value();
if (!isStMatrix) {
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
}
assert(shmemStoreLayout.has_value());
bool isStMatrix = targetInfo.canUseStMatrix(
op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
LinearLayout shmemStoreLayout =
isStMatrix ? chooseStMatrixLayout(
ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0)
: srcLayout.invertAndCompose(sharedLayout);

const int shmemAllocatedNumElems =
getNumScratchElements(scratchConfig.paddedRepShape);
assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);

// Layout for the load from shmem to registers.
LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);

// Check that the `register` fully determines the `iteration`. That is,
// each thread does exactly the same reads and writes to shmem on each
// iteration, just with different input/output registers.
assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
{kIteration}));
assert(
shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
assert(
shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));

// iteration -> registers
SmallVector<SmallVector<int>> inRegsForIter =
collectRegsForIter(ctx, *shmemStoreLayout);
collectRegsForIter(ctx, shmemStoreLayout);
SmallVector<SmallVector<int>> outRegsForIter =
collectRegsForIter(ctx, shmemLoadLayout);

Expand Down Expand Up @@ -582,7 +573,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return vecAddr;
};

auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
Expand All @@ -605,11 +596,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

// When using `stmatrix`, we can store `inVec` elements even if they are
// not contiguous
auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
: scratchConfig.inVec;
for (int j = 0; j < inVals.size() / iterations; j += inVec) {
auto inRegSlice = inRegs[j];
Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
SmallVector<Value> inValsVec;
for (int k = 0; k < inVec; k++)
inValsVec.push_back(inVals[inRegSlice + k]);
Expand Down
44 changes: 29 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
static bool isSupportedDotOpLayout(MemDescType srcTy,
RankedTensorType dstTy) {
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
auto bitwidth = dstTy.getElementTypeBitWidth();
auto rank = dstTy.getRank();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
auto needTrans = kOrder != srcLayout.getOrder()[0];
auto canUseLdmatrix =
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
if (mma.isHopper()) {
// I think we should be able to remove this condition, but it's here
// as the legacy ldmatrix path does not support it
canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
}
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though
canUseLdmatrix &=
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
return legacyLoweringIsBuggy && mma.isAmpere();
(kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
return (mma.isHopper() && !canUseLdmatrix) ||
(mma.isAmpere() && legacyLoweringIsBuggy);
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
Expand All @@ -162,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(srcTy, dstTy)) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
Expand Down Expand Up @@ -206,7 +220,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
Expand Down
58 changes: 8 additions & 50 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
llvm::report_fatal_error("Illegal shared layout");
}

int vec = 8 * 16 / elemBitWidth;
if (vec != shared.getVec()) {
llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
<< ": " << shared << "\n";
llvm::report_fatal_error("Illegal shared layout");
}
int vec = shared.getVec();

StringAttr colDimName = outDimNames[colDim];
StringAttr rowDimName = outDimNames[rowDim];
Expand Down Expand Up @@ -858,40 +853,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
}

namespace {

// TODO (Keren): Currently, we have more restrictions than necessary when using
// stmatrix. These restrictions are retained from legacy code, and we could
// relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
auto mmaLayout =
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
return false;
if (isa<PointerType>(tensorTy.getElementType()))
return false;
if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
return false;
if (order[0] != 1)
return false;

auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
if (tensorShapePerCTA.size() != 2)
return false;
auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
if (numIterations > 1)
return false;
if (paddedRepShape[1] % 8 != 0)
return false;
if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
swizzleByteSize != 128)
return false;
return true;
}

std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
LinearLayout chooseStMatrixLayoutLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
Expand Down Expand Up @@ -962,7 +924,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
.reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout chooseStMatrixLayoutNoLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
StringAttr kReg = S("register");
Expand Down Expand Up @@ -1002,15 +964,11 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(

} // anonymous namespace

std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
swizzleByteSize))
return std::nullopt;

LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) {
if (swizzleByteSize == 0)
return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
paddedRepShape, order);
Expand Down
15 changes: 12 additions & 3 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
dtypes_with_bfloat16,
is_cuda,
is_interpreter,
is_hopper,
is_hip,
is_hip_cdna,
is_hip_mi200,
Expand Down Expand Up @@ -195,7 +196,12 @@ def is_layout_applicable(layout) -> bool:
if layout in common_layouts:
return True
elif is_cuda():
return isinstance(layout, MmaLayout)
mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
if not isinstance(mma_layout, MmaLayout):
return False
if mma_layout.version[0] >= 3 and not is_hopper():
return False
return True
elif is_hip():
target_arch = triton.runtime.driver.active.get_current_target().arch
if "gfx11" in target_arch:
Expand Down Expand Up @@ -5246,6 +5252,9 @@ def kernel(Out):
# TODO: backend should be tested separately

layouts = [
MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
Expand Down Expand Up @@ -5293,9 +5302,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):

@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
if str(src_layout) == str(dst_layout):
pytest.skip()
Expand Down
Loading

0 comments on commit e2dc77b

Please sign in to comment.