[LAYOUTS] Use LLs for Hopper whenever we wouldn't use ldmatrix #5235

Merged: 6 commits, Nov 26, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -26,6 +26,9 @@ python/triton/language/extra
# Proton
python/triton/profiler

# Pytest
pytest.ini

Comment on lines +29 to +31 (Contributor Author):
Tiny .gitignore change for pytest.

# Instrumentation
python/triton/instrumentation

7 changes: 7 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -4,6 +4,7 @@
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton {

class TargetInfoBase {
public:
virtual bool supportMaximumMinimum() const = 0;
@@ -37,6 +38,12 @@ class TargetInfoBase {
pred);
}

virtual bool canUseStMatrix(RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) const = 0;

virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
Value ptr, Value val) const = 0;

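A concrete target is expected to implement this new hook with the Hopper stmatrix restrictions that previously lived as a static helper in LinearLayoutConversions.cpp (removed further down in this diff). A rough sketch of what such an override might look like, using a hypothetical class name and only a subset of those checks:

```cpp
// Illustrative sketch only, not the actual NVIDIA override. It mirrors a few
// of the checks from the static canUseStMatrix helper removed later in this PR.
bool MyTargetInfo::canUseStMatrix(RankedTensorType tensorTy,
                                  ArrayRef<unsigned> repShape,
                                  ArrayRef<unsigned> paddedRepShape,
                                  ArrayRef<unsigned> order,
                                  int swizzleByteSize) const {
  auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
  if (!mma || !mma.isHopper())
    return false;
  // stmatrix stores 16-bit elements along the fastest-varying dimension.
  if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16 || order[0] != 1)
    return false;
  // The remaining shape, iteration-count, and swizzle-size restrictions from
  // the removed helper would go here.
  return true;
}
```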
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
// bit width of the tensor in the future to support more flexible tensor
// encodings
std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);
LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == rank-1) ? k : m;
int mmaStride = (order[0] == rank-1) ? m : k;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

@@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == rank-1) ? n : k;
int mmaStride = (order[0] == rank-1) ? k : n;
int maxPhase = mmaStride / perPhase;
int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

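The std::max clamp in both hunks above guards against a degenerate swizzle configuration: when perPhase exceeds mmaStride, the integer division used to yield maxPhase == 0. A quick standalone check of the arithmetic (the values are illustrative, not taken from a specific layout):

```cpp
#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical small-tile case where perPhase is larger than mmaStride.
  int mmaStride = 2, perPhase = 4;
  int oldMaxPhase = mmaStride / perPhase;              // 0: invalid phase count
  int newMaxPhase = std::max(mmaStride / perPhase, 1); // clamped to 1
  assert(oldMaxPhase == 0 && newMaxPhase == 1);
  return 0;
}
```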
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -629,7 +629,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
(elementTypeSize == 16 || elementTypeSize == 8) &&
dotOperandLayout.getKWidth() == 32 / elementTypeSize;
return ans;
}

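The new kWidth condition limits this fast path to operands where a 32-bit register holds exactly kWidth elements. A minimal illustration of that arithmetic, assuming 16-bit and 8-bit element types:

```cpp
#include <cassert>

int main() {
  auto expectedKWidth = [](int elementTypeSize) { return 32 / elementTypeSize; };
  assert(expectedKWidth(16) == 2); // fp16/bf16 dot operands must use kWidth == 2
  assert(expectedKWidth(8) == 4);  // 8-bit dot operands must use kWidth == 4
  return 0;
}
```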
61 changes: 26 additions & 35 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -376,28 +376,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// completed before we can remove the layoutIsOK check:
// 1. Support for AMD's WMMA
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
return !useLegacyMMAConversion;
}
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto parent = dotOperand.getParent();
if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
return false;
}
if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (nvidiaMma.isAmpere()) {
return true;
}
}
if (isa<AMDMfmaEncodingAttr>(parent)) {
return true;
}
return false;
layout = dotOperand.getParent();
@binarman (Contributor) commented on Nov 26, 2024:
I am not sure this is a correct change. This code will return true for blocked FMA dot operands, but I don't see a linear converter for them.

Contributor:
Yeah, I agree that blocked FMA has to be fixed

Contributor Author:
ah, right, let me send a fix

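A sketch of the kind of follow-up guard being discussed (illustrative only; the actual fix landed separately) would keep rejecting FMA dot operands whose parent is a plain blocked layout until a linear-layout converter exists for them:

```cpp
// Hypothetical follow-up, not the merged fix: blocked-parent (FMA) dot operands
// have no linear-layout converter yet, so report them as unsupported.
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
  if (isa<BlockedEncodingAttr>(dotOperand.getParent()))
    return false;
  layout = dotOperand.getParent();
}
```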
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;

if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
return !useLegacyMMAConversion;
}
if (isa<LinearEncodingAttr>(layout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -408,6 +394,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
return failure();
}
// FIXME [Dot LL] Remove this once we implement this trick in LLs
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
return failure();
}

// The following check can be removed when generalized warp shuffle
// conversions are ready:
@@ -504,34 +494,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// don't need to avoid duplicate writes.
// Input dims: [reg, lane, warp]
// Output dims: [offset, iteration]
std::optional<LinearLayout> shmemStoreLayout =
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
bool isStMatrix = shmemStoreLayout.has_value();
if (!isStMatrix) {
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
}
assert(shmemStoreLayout.has_value());
bool isStMatrix = targetInfo.canUseStMatrix(
op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
LinearLayout shmemStoreLayout =
isStMatrix ? chooseStMatrixLayout(
ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0)
: srcLayout.invertAndCompose(sharedLayout);

const int shmemAllocatedNumElems =
getNumScratchElements(scratchConfig.paddedRepShape);
assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);

// Layout for the load from shmem to registers.
LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);

// Check that the `register` fully determines the `iteration`. That is,
// each thread does exactly the same reads and writes to shmem on each
// iteration, just with different input/output registers.
assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
{kIteration}));
assert(
shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
assert(
shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));

// iteration -> registers
SmallVector<SmallVector<int>> inRegsForIter =
collectRegsForIter(ctx, *shmemStoreLayout);
collectRegsForIter(ctx, shmemStoreLayout);
SmallVector<SmallVector<int>> outRegsForIter =
collectRegsForIter(ctx, shmemLoadLayout);

Expand Down Expand Up @@ -588,7 +579,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return vecAddr;
};

auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
@@ -611,11 +602,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

// When using `stmatrix`, we can store `inVec` elements even if they are
// not contiguous
auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
: scratchConfig.inVec;
for (int j = 0; j < inVals.size() / iterations; j += inVec) {
auto inRegSlice = inRegs[j];
Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
SmallVector<Value> inValsVec;
for (int k = 0; k < inVec; k++)
inValsVec.push_back(inVals[inRegSlice + k]);
44 changes: 29 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
static bool isSupportedDotOpLayout(MemDescType srcTy,
RankedTensorType dstTy) {
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
auto bitwidth = dstTy.getElementTypeBitWidth();
auto rank = dstTy.getRank();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
auto needTrans = kOrder != srcLayout.getOrder()[0];
auto canUseLdmatrix =
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
if (mma.isHopper()) {
// I think we should be able to remove this condition, but it's here
// as the legacy ldmatrix path does not support it
canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
}
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though
canUseLdmatrix &=
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
return legacyLoweringIsBuggy && mma.isAmpere();
(kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
return (mma.isHopper() && !canUseLdmatrix) ||
(mma.isAmpere() && legacyLoweringIsBuggy);
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
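To make the Hopper decision above concrete, here is a small standalone mirror of the vector-width part of the check with two sample operand configurations (illustration only; the real predicate also considers transposition, the shared-memory element width, and minimum shapes):

```cpp
#include <cassert>

// Mirrors only the kWidth-vs-register-width part of canUseLdmatrix above.
bool ldmatrixVectorWidthMatches(int bitwidth, int kWidth) {
  int vecWidth = 32 / bitwidth; // elements per 32-bit register
  return kWidth == vecWidth;
}

int main() {
  // fp16 operand with kWidth == 2: ldmatrix remains usable on Hopper.
  assert(ldmatrixVectorWidthMatches(16, 2));
  // 8-bit operand with kWidth == 8: no ldmatrix, so Hopper now takes the
  // linear-layout lowering instead, which is the point of this PR.
  assert(!ldmatrixVectorWidthMatches(8, 8));
  return 0;
}
```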
@@ -162,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(srcTy, dstTy)) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -206,7 +220,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
58 changes: 8 additions & 50 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -241,12 +241,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
llvm::report_fatal_error("Illegal shared layout");
}

int vec = 8 * 16 / elemBitWidth;
if (vec != shared.getVec()) {
llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
<< ": " << shared << "\n";
llvm::report_fatal_error("Illegal shared layout");
}
int vec = shared.getVec();

StringAttr colDimName = outDimNames[colDim];
StringAttr rowDimName = outDimNames[rowDim];
@@ -858,40 +853,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
}

namespace {

// TODO (Keren): Currently, we have more restrictions than necessary when using
// stmatrix. These restrictions are retained from legacy code, and we could
// relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
auto mmaLayout =
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
return false;
if (isa<PointerType>(tensorTy.getElementType()))
return false;
if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
return false;
if (order[0] != 1)
return false;

auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
if (tensorShapePerCTA.size() != 2)
return false;
auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
if (numIterations > 1)
return false;
if (paddedRepShape[1] % 8 != 0)
return false;
if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
swizzleByteSize != 128)
return false;
return true;
}

std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
LinearLayout chooseStMatrixLayoutLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
@@ -962,7 +924,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
.reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout chooseStMatrixLayoutNoLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
StringAttr kReg = S("register");
@@ -1002,15 +964,11 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(

} // anonymous namespace

std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
swizzleByteSize))
return std::nullopt;

LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize) {
if (swizzleByteSize == 0)
return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
paddedRepShape, order);
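With the std::optional wrapper removed, the check-then-construct pattern moves to call sites, as the ConvertLayoutOpToLLVM.cpp hunk above shows. A condensed restatement of that calling convention, using names from this PR:

```cpp
// Gate on TargetInfo::canUseStMatrix first; chooseStMatrixLayout can then
// return a LinearLayout unconditionally instead of std::optional.
bool isStMatrix = targetInfo.canUseStMatrix(tensorTy, repShape, paddedRepShape,
                                            order, /*swizzleByteSize=*/0);
LinearLayout storeLayout =
    isStMatrix ? chooseStMatrixLayout(ctx, tensorTy, repShape, paddedRepShape,
                                      order, /*swizzleByteSize=*/0)
               : srcLayout.invertAndCompose(sharedLayout);
```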
15 changes: 12 additions & 3 deletions python/test/unit/language/test_core.py
@@ -28,6 +28,7 @@
dtypes_with_bfloat16,
is_cuda,
is_interpreter,
is_hopper,
is_hip,
is_hip_cdna,
is_hip_mi200,
@@ -195,7 +196,12 @@ def is_layout_applicable(layout) -> bool:
if layout in common_layouts:
return True
elif is_cuda():
return isinstance(layout, MmaLayout)
mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
if not isinstance(mma_layout, MmaLayout):
return False
if mma_layout.version[0] >= 3 and not is_hopper():
return False
return True
elif is_hip():
target_arch = triton.runtime.driver.active.get_current_target().arch
if "gfx11" in target_arch:
@@ -5250,6 +5256,9 @@ def kernel(Out):
# TODO: backend should be tested separately

layouts = [
MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -5297,9 +5306,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):

@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
if str(src_layout) == str(dst_layout):
pytest.skip()