Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[AMD] remove redundant LDS bypass checks #5002

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

assert(!isMfmaToDotShortcut(srcTy, dstTy));
assert(cvtNeedsSharedMemory(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
Expand Down
19 changes: 1 addition & 18 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return matrixDimsCompatible && bDimCompatible;
}

// Returns true when a convert_layout from an MFMA-encoded tensor to a dot
// operand encoding can bypass shared memory (LDS): in that case both layouts
// place the same elements in the same registers, so the conversion is a no-op
// re-pack. Any condition failing means the generic shared-memory path is
// required.
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto srcMfma = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
  auto dstDotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  // Shortcut only applies to the MFMA -> dot-operand pair.
  if (!srcMfma || !dstDotOp)
    return false;
  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
  // improved. In addition, we can enable this shortcut for regular MFMA
  // layout when opIdx == 1.
  if (srcMfma.getWarpsPerCTA()[1] != 1)
    return false;
  if (dstDotOp.getOpIdx() != 0)
    return false;
  if (!srcMfma.getIsTransposed())
    return false;
  // Per-thread contiguity must line up with the operand's kWidth so the
  // register contents coincide.
  if (dstDotOp.getKWidth() != getContigPerThread(srcMfma)[1])
    return false;
  if (dstDotOp.getParent() != srcMfma)
    return false;
  const bool dimSupported = srcMfma.getMDim() == 32 || srcMfma.getMDim() == 16;
  const bool elemTySupported =
      srcTy.getElementType().isF16() || srcTy.getElementType().isBF16();
  return dimSupported && elemTySupported;
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Expand Down Expand Up @@ -738,8 +722,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,64 +115,13 @@ struct LocalLoadOpConversion
}
};

// Lowers triton_gpu.convert_layout for the MFMA -> dot-operand shortcut case,
// where source and destination layouts agree on in-register element placement
// and the conversion reduces to re-packing the values a thread already holds.
struct ConvertLayoutOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
  using ConvertOpToLLVMPattern<
      triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
    auto dstTy = cast<RankedTensorType>(op.getResult().getType());
    // This pattern handles only MFMA -> dot-operand conversions; anything
    // else is left for other lowering patterns.
    if (!isa<AMDMfmaEncodingAttr>(srcTy.getEncoding()) ||
        !isa<DotOperandEncodingAttr>(dstTy.getEncoding()))
      return failure();
    // The shortcut holds when the number of sequential elements stored by one
    // thread matches between the two encodings:
    // - for the MFMA encoding (result of a dot op) it is 4;
    // - for the MFMA operand encoding it is dotOperandEncoding::kWidth,
    //   which is 4 in certain cases (e.g. fp16/bfloat16 with kpack = 1).
    // When those agree, the MFMA and MFMA-operand layouts are identical.
    if (!isMfmaToDotShortcut(srcTy, dstTy))
      return failure();
    auto loc = op.getLoc();
    // Pure re-pack: unpack the adaptor's values and pack them back with the
    // destination type — no data movement through shared memory.
    auto elems = unpackLLElements(loc, adaptor.getSrc(), rewriter);
    Value repacked =
        packLLElements(loc, getTypeConverter(), elems, rewriter, dstTy);
    rewriter.replaceOp(op, repacked);
    return success();
  }
};
} // namespace

namespace mlir::triton::AMD {
// Registers the AMD-specific convert_layout lowering patterns with the
// rewrite set, all at the same pattern benefit.
// NOTE(review): numWarps, targetInfo, and axisInfoAnalysis are accepted but
// unused here — presumably kept for signature parity with the other
// populate* hooks; confirm against the callers.
void populateConvertLayoutOpToLLVMPatterns(
    LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
    RewritePatternSet &patterns, int numWarps,
    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
  patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
  patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
}
} // namespace mlir::triton::AMD
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions

triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
isMfmaToDotShortcut);
auto isShortcut =
mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

/* -------------------------------- */
// Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
Expand Down
Loading