diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 37d24ac929a9..df6029db0de2 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);
 
 bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 276a6e7004df..665b97aeebba 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();
 
-  assert(!isMfmaToDotShortcut(srcTy, dstTy));
+  assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
   // FIXME This is NOT entirely correct
   // This should be getElemOrder, but we don't have such a method
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 30ba11c31782..aa9f8b01eae1 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   return matrixDimsCompatible && bDimCompatible;
 }
 
-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
-    return false;
-  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
-  // improved. In addition, we can enable this shortcut for regular MFMA
-  // layout when opIdx == 1.
-  return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
-         dotOperandLayout.getParent() == mfmaLayout &&
-         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
-         (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
-}
-
 // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy) {
@@ -738,8 +722,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
-         !isMfmaToDotShortcut(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
 }
 
 bool atomicNeedsSharedMemory(Value value) {
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index b7ee4efc72d0..d3ffaed2e8fc 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -115,56 +115,6 @@ struct LocalLoadOpConversion
   }
 };
 
-struct ConvertLayoutOpConversion
-    : public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
-public:
-  using ConvertOpToLLVMPattern<
-      triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value src = op.getSrc();
-    Value dst = op.getResult();
-    auto srcTy = cast<RankedTensorType>(src.getType());
-    auto dstTy = cast<RankedTensorType>(dst.getType());
-    Attribute srcLayout = srcTy.getEncoding();
-    Attribute dstLayout = dstTy.getEncoding();
-
-    if (isa<AMDMfmaEncodingAttr>(srcLayout) &&
-        isa<DotOperandEncodingAttr>(dstLayout)) {
-      return lowerMfmaToDotOperand(op, adaptor, rewriter);
-    }
-    return failure();
-  }
-
-private:
-  LogicalResult
-  lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType srcTy = op.getSrc().getType();
-    RankedTensorType dstTy = op.getType();
-    if (isMfmaToDotShortcut(srcTy, dstTy)) {
-      // vecSize is an number of sequential elements stored by one thread
-      // - For MFMA encoding (encoding of the result tensor of dot
-      // operation) it is 4
-      // - For MFMA operand encoding it is
-      // dotOperandEncoding::kWidth,
-      // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack
-      // = 1)
-      //
-      // For cases where these two values are equal MFMA and MFMA operand
-      // layouts are the same.
-      auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      Value view =
-          packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy);
-      rewriter.replaceOp(op, view);
-      return success();
-    }
-    return failure();
-  }
-};
 } // namespace
 
 namespace mlir::triton::AMD {
@@ -172,7 +122,6 @@ void populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns, int numWarps,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
-  patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
 }
 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp
index cece47227ea0..bce126ea4d72 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions
 
     triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);
 
-    triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
-                                                          isMfmaToDotShortcut);
+    auto isShortcut =
+        mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));
+
+    triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);
 
     /* -------------------------------- */
     // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
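
Reviewer note, not part of the patch: the DecomposeUnsupportedConversions hunk replaces the bespoke isMfmaToDotShortcut with the negation of cvtNeedsSharedMemory via std::not_fn, so the "shortcut" and "needs shared memory" predicates can no longer drift apart. A minimal, self-contained sketch of that composition follows; TensorType, the toy cvtNeedsSharedMemory, and the ShortcutFn alias here are simplified stand-ins, not Triton's real types or API.

// Sketch only (simplified stand-ins, not Triton's real types/API): shows how
// std::not_fn turns a "needs shared memory" predicate into the "shortcut"
// predicate handed to decomposeTensorCoreToDotLayoutConversion.
#include <functional>
#include <iostream>

struct TensorType {
  bool sameLayout; // toy stand-in for a real layout encoding
};

// Stand-in for cvtNeedsSharedMemory: true when the conversion cannot stay in
// registers and must round-trip through shared memory.
bool cvtNeedsSharedMemory(TensorType src, TensorType dst) {
  return !(src.sameLayout && dst.sameLayout);
}

// Stand-in for mlir::triton::gpu::ShortcutFn.
using ShortcutFn = std::function<bool(TensorType, TensorType)>;

int main() {
  // The shortcut predicate is defined as the exact negation of
  // cvtNeedsSharedMemory, mirroring the change in the patch above.
  ShortcutFn isShortcut = std::not_fn(cvtNeedsSharedMemory);

  TensorType a{true}, b{true}, c{false};
  std::cout << std::boolalpha << isShortcut(a, b) << "\n"; // true
  std::cout << isShortcut(a, c) << "\n";                   // false
}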