diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index df6029db0de2..cb3e3d292efa 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -212,7 +212,11 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); + +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 8c7ab9831667..22c8f9c8a330 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -18,6 +18,14 @@ namespace gpu { SmallVector reorderValues(const SmallVector &values, Type inType, Type ouType); +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + Type getElementType(Value value); class MultipleOperandsRange @@ -179,8 +187,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32s(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); + subOperands = unpackI32(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -207,7 +215,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } resultVals = maybeDeduplicate(op, resultVals); resultVals = - packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 56a82d7cc0fb..29b8865c03ae 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1388,67 +1388,6 @@ inline Value getStructFromSharedMemoryObject(Location loc, return llvmStruct; } -// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer -// instructions to pack & unpack sub-word integers. A workaround is to -// store the results of tensors with dot operand encodings in i32 to -// facilitate instructions such as `ldmatrix`. -// -// TODO: Confirm if the problem is still there. 
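For orientation, a minimal host-side sketch of the packing scheme that the removed helpers below implement at the LLVM IR level. This is an illustration only (hypothetical function names, plain C++ rather than IR), assuming 16-bit elements packed two per i32 with element i in the low half, which is what a little-endian <2 x i16> to i32 bitcast produces.

#include <cstdint>
#include <vector>

// Hypothetical sketch: two 16-bit elements per i32, element i in bits 0-15,
// element i+1 in bits 16-31.
std::vector<uint32_t> packToI32(const std::vector<uint16_t> &elems) {
  std::vector<uint32_t> packed;
  for (size_t i = 0; i + 1 < elems.size(); i += 2)
    packed.push_back(uint32_t(elems[i]) | (uint32_t(elems[i + 1]) << 16));
  return packed;
}

std::vector<uint16_t> unpackFromI32(const std::vector<uint32_t> &packed) {
  std::vector<uint16_t> elems;
  for (uint32_t v : packed) {
    elems.push_back(uint16_t(v & 0xffff));
    elems.push_back(uint16_t(v >> 16));
  }
  return elems;
}

The helpers being removed below generate the equivalent insert_element/extract_element and bitcast sequence directly in LLVM IR.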
-inline bool requiresI32Conversion(Type type) { - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return false; - auto dotOpEnc = dyn_cast(tensorTy.getEncoding()); - if (!dotOpEnc) - return false; - auto parent = dyn_cast(dotOpEnc.getParent()); - if (!(parent && parent.getVersionMajor() < 3)) - return false; - return true; -} - -inline SmallVector packI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - int vecWidth = 32 / eltTy.getIntOrFloatBitWidth(); - auto vecTy = vec_ty(eltTy, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecTy); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} - -inline SmallVector unpackI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - for (auto v : inValues) { - auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecTy); - for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - inline SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 47e3fca79bb1..c728cfbb32cf 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -679,13 +679,6 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; - // Increase an input dimension without affecting the output dimension. The - // added free variables are mapped to 0, ensuring that the new input - // dimensions correspond directly to the existing output space. The function - // errors out if `newInDimSize` is less than the current size or the new size - // is not a power of 2. 
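For orientation, a small hypothetical sketch of the resize() contract documented above: growing an input dimension appends all-zero basis vectors, so the added inputs act as free variables that map to 0 and alias the existing output space. Plain C++ with simplified types, not the real LinearLayout API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Basis = std::vector<int32_t>; // one output coordinate per out-dim

// bases holds log2(old in-dim size) entries; returns the padded basis list.
std::vector<Basis> resizeInDim(std::vector<Basis> bases, int32_t newInDimSize,
                               size_t numOutDims) {
  assert((newInDimSize & (newInDimSize - 1)) == 0 && "must be a power of 2");
  size_t newLog2 = 0;
  while ((1 << newLog2) < newInDimSize)
    ++newLog2;
  assert(newLog2 >= bases.size() && "newInDimSize must be >= old size");
  while (bases.size() < newLog2)
    bases.push_back(Basis(numOutDims, 0)); // added inputs map to 0
  return bases;
}

This mirrors the removed implementation further down, which appends Log2_32(newInDimSize) - getInDimSizeLog2(inDim) zero vectors to bases[inDim].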
- LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; - std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 665b97aeebba..276a6e7004df 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - assert(cvtNeedsSharedMemory(srcTy, dstTy)); + assert(!isMfmaToDotShortcut(srcTy, dstTy)); // FIXME This is NOT entirely correct // This should be getElemOrder, but we don't have such a method diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index aa9f8b01eae1..4915d7b1acda 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { return matrixDimsCompatible && bDimCompatible; } +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; + // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is + // improved. In addition, we can enable this shortcut for regular MFMA + // layout when opIdx == 1. + return mfmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getParent() == mfmaLayout && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); +} + // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -639,46 +655,8 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; - StringAttr kRegister = StringAttr::get(ctx, "register"); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - auto numSrcRegs = srcLayout->getInDimSize(kRegister); - auto numDstRegs = dstLayout->getInDimSize(kRegister); - // The `invertAndCompose` function will generate a layout that is injective - // by assigning new output dimensions to free variables. For instance, - // consider a scenario where `srcLayout` has a free variable in the lane - // dimension, while `dstLayout` has two free variables in the lane - // dimension and also a larger number of registers. - // The injective form of `srcLayout` will add only a single additional row - // to the transformation matrix, whereas the injective form of `dstLayout` - // will add two additional rows. 
This discrepancy causes misleading results - // because the matrices end up with a different number of rows. - // - // Take `dstLayout ⋅ srcLayout^-1` as an example: - // - // - `injective(dstLayout)`: [n, m] → [n + 2, m] - // - `injective(srcLayout)`: [n, m] → [n + 1, m] - // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] - // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + - // 1] → [n + 2, n + 1] - // - // Here, the `(n + 1)`-th row added by `dstLayout` represents the free - // variable in registers, and the `(n + 2)`-th row represents the free - // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` - // represents the free variable in lanes. As a result, the `(n + 1)`-th row - // in two layouts do not correspond to the same free variable. - // - // To address this issue, we pad the free variables in `srcLayout` and - // `dstLayout` to ensure they have the same number of registers. This - // guarantees that the resulting matrices have the same number of rows, - // ensuring consistency in the composition process. - auto numRegs = std::max(numSrcRegs, numDstRegs); - auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); - auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = - dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); + LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { @@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and - // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout - // checks. + // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, + // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully + // subsumed by the linear-layout checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. 
return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy); + !isMmaToDotShortcut(srcTy, dstTy) && + !isMfmaToDotShortcut(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { @@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) { return true; } +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; + // dot_op = #mma + // when #mma = MmaEncoding + auto mmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +} + namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 65ee8cc0023e..a18b2cbc308c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -288,90 +288,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return rewriter.notifyMatchFailure( op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - - StringAttr kBlock = str_attr("block"); - StringAttr kWarp = str_attr("warp"); - StringAttr kLane = str_attr("lane"); - StringAttr kRegister = str_attr("register"); assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, kBlock)) { + if (llvm::is_contained(dims, str_attr("block"))) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, kWarp)) { + } else if (llvm::is_contained(dims, str_attr("warp"))) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kLane)) { + } else if (llvm::is_contained(dims, str_attr("lane"))) { // Case 3. Transfer between values in the same warp, in which case we try // to move values using warp shuffles, though if the pattern is // complicated enough we may fall back to using shared memory // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kRegister) || - dstLayout.getInDimSize(kRegister) != - srcLayout.getInDimSize(kRegister)) { + } else if (llvm::is_contained(dims, str_attr("register"))) { // Case 4. 
Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread( - op, dstLayout.getFreeVariableMasks()[kRegister], - dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + return transferWithinThread(op, *conversion, adaptor, rewriter); } else { - // Cast 5. The two layouts are equivalent. We should probably remove - // these in RemoveLayoutConversion. - auto dstCvt = requiresI32Conversion(dstTy); - auto srcCvt = requiresI32Conversion(srcTy); - if (dstCvt || srcCvt) { - auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), - getTypeConverter()); - inVals = - packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); - auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, - rewriter, op.getType()); - rewriter.replaceOp(op, res); - } else { - rewriter.replaceOp(op, adaptor.getSrc()); - } + // The two layouts are equivalent. We should probably remove these in + // RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } LogicalResult - transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, - const LinearLayout &conversion, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); - SmallVector outVals(numRegs); - for (int i = 0; i < numRegs; i++) { - // Remove free masks from the register index - // For example, if idx = 0b00111, and masks = 0b00100, then we get - // 0b00011. It means that register 7 (0b111) has the same value as - // register 3 (0b011). - auto idx = i & (~regMasks); - auto srcIdx = conversion.hasInDim(kRegister) - ? 
conversion.apply({{kRegister, idx}}).begin()->second - : idx; + SmallVector outVals; + outVals.resize(conversion.getInDimSize(kRegister)); + for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -403,6 +375,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (auto dotOperand = dyn_cast(layout)) { if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { + if (product(getCTAsPerCGA(nvidiaMma)) > 1) { + return false; + } if (useLegacyMMAConversion) { return false; } @@ -412,7 +387,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; return largeKWidth && nvidiaMma.isAmpere(); } - return false; } if (isa(layout)) { return true; @@ -454,7 +428,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); } } - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); // Pretty sure this is the identity function ATM // It'd be better to simply call `quotient({kBlock})` and @@ -474,7 +447,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); + // FIXME [Dot LL] + // We know it's just for largeKWidth case in Ampere + // In this case, we need to pack the outputs into i32 + if (isa(dstTy.getEncoding())) { + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 470e8b32b540..8ee166866974 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -103,6 +103,51 @@ SmallVector reorderValues(const SmallVector &values, Type inType, llvm_unreachable("unimplemented code path"); } +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto 
encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; @@ -455,7 +500,7 @@ struct ElementwiseInlineAsmOpConversion auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); unpackedOperands.push_back( - unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -515,11 +560,10 @@ struct ElementwiseInlineAsmOpConversion unpackedResults[i], /*inType=*/op->getOperand(0).getType(), /*ouType=*/op->getResult(i).getType()); } - auto dstTy = op->getResult(i).getType(); - unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc, - getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], - rewriter, op->getResult(i).getType())); + auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), + rewriter, loc, getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, + op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index e2ed0228de8d..1a0c115a9ecf 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -184,7 +184,42 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter); + // FIXME [Dot LL] + // Ampere case + // In this case, we need to pack the outputs into i32 + if (auto dotOp = dyn_cast(dstTy.getEncoding())) { + if (auto parent = dyn_cast(dotOp.getParent())) { + if (parent.isAmpere()) { + if (elemLlvmTy.isInteger(8)) { + auto concat = [&](Value a1, Value a2, Value a3, Value a4) { + return or_( + or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), + or_(shl(zext(i32_ty, a3), i32_val(16)), + shl(zext(i32_ty, a4), i32_val(24)))); + }; + SmallVector outVals32(outVals.size() / 4); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], + outVals[4 * i + 2], outVals[4 * i + 3]); + } + outVals = outVals32; + } else { + assert(elemLlvmTy.isBF16() && "Unexpected element type"); + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + } + } + } + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git 
a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 6978ccfb2553..56af4eaef8b9 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -551,8 +551,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. To achieve @@ -895,7 +895,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); auto order = dot.getCTAOrder(); - assert(order[0] == rank - 1 && order[1] == rank - 2); + assert(order[0] == 1 && order[1] == 0); ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); @@ -903,11 +903,13 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - auto parent = getParent(); - if (auto mfmaLayout = llvm::dyn_cast(parent)) { - return mfmaDotToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(parent)) { - if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) { + if (auto mfmaLayout = llvm::dyn_cast(getParent())) { + return dotOperandMfmaToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(getParent())) { + // FIXME [Dot LL] + // Do this unconditionally + auto largeKWidth = getKWidth() == 8; + if (mma.isAmpere() && largeKWidth) { return ampereDotToLinearLayout(shape, *this); } } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9f3d8fff491b..6d8279795209 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -290,7 +290,7 @@ struct MMAV3UseRegOperand dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) + if (!isMmaToDotShortcut(srcTy, newTy)) return failure(); Value newOperand = diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 4319d1f086dd..bf017f8c6463 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,21 +1016,6 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } -LinearLayout LinearLayout::resize(StringAttr inDim, - int32_t newInDimSize) const { - BasesT bases = getBases(); - assert(bases.contains(inDim) && "inDim not in layout"); - assert(llvm::isPowerOf2_32(newInDimSize) && - "newInDimSize must be a power of 2"); - assert(newInDimSize >= getInDimSize(inDim) && - "newInDimSize must be >= old size"); - auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); - for (int i = 0; i < numFreeVariables; i++) { - bases[inDim].push_back(std::vector(getNumOutDims(), 0)); - } - return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); -} - std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git 
a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 208f6b80bfe5..a0719c974f9c 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -6,7 +6,7 @@ #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index df4f5ab01feb..2054853b30c1 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,7 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index a2f713faaf18..72e02a4ef46e 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -67,15 +67,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16x2 - tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + // CHECK-LABEL: atomic_add_f16 + tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> // CHECK: llvm.cond_br - // CHECK-NOT: rocdl.update.dpp // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> - // CHECK-NOT: rocdl.update.dpp %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> tt.return } @@ -85,51 +83,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16x2 - tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + // CHECK-LABEL: 
atomic_add_bf16 + tt.func @atomic_add_bf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> // CHECK: llvm.cond_br - // CHECK-NOT: rocdl.update.dpp // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> - // CHECK-NOT: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> - tt.return - } -} - -// ----- - -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16_dpp - tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> - // CHECK: llvm.cond_br - // CHECK: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> - // CHECK: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> - tt.return - } -} - -// ----- - -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16_dpp - tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> - // CHECK: llvm.cond_br - // CHECK: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> - // CHECK: rocdl.update.dpp %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 325c425a2277..34573f7739b8 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -821,110 +821,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: 
llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg - tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_0 - tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_1 - tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_2 - tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> 
- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_3 - tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> - tt.return - } -} - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -949,80 +845,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_0 - tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_1 - tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_2 - tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 
64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_3 - tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> - tt.return - } -} - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 5dfd0f2a5f4c..686e5a24e8dd 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -460,6 +460,429 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } } +// ----- +// This test ensures that loads will not be moved across `for` loops. + +// CHECK-LABEL: tt.func public @_attn_bwd +// CHECK: tt.load +// CHECK: tt.load +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// Moved before the independent `tt.store` ops but not before the `for` ops. +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.store +// CHECK: tt.store +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// CHECK: tt.store + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> + %0 = tt.get_program_id z : i32 + %1 = arith.muli %0, %arg14 : i32 + %2 = arith.extsi %1 : i32 to i64 + %3 = arith.remsi %0, %arg13 : i32 + %4 = arith.muli %arg11, %3 : i32 + %5 = arith.divsi %0, %arg13 : i32 + %6 = arith.muli %arg10, %5 : i32 + %7 = arith.addi %4, %6 : i32 + %8 = arith.extsi %7 : i32 to i64 + %9 = tt.get_program_id x : i32 + %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 + %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 + %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 + %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 + %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 + %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 + %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 + %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 + %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 + %19 = arith.muli %9, %c128_i32 : i32 + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> + %38 = tt.splat %arg12 : 
i32 -> tensor<128x1xi32, #mma> + %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> + %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> + %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> + %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> + %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> + %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> + %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> + %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> + %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> + %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> + %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> + %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, 
#triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> + %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> + %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> + %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> + %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> + %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> + %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> + %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> + %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> + %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %90 = arith.muli %arg12, %c16_i32 : i32 + %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> + %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> + %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { + %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> 
tensor<1x16xi32, #mma1> + %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> + %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> + %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> + %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> + %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %248 = 
triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %250 = arith.addi %arg18, %c16_i32 : i32 + %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> + scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> + } + %94 = arith.addi %19, %c128_i32 : i32 + %95 = arith.subi %arg14, %94 : i32 + %96 = arith.divsi %95, %c32_i32 : i32 + %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> + %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> + %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> + %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> + %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> + %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> + %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %122 = arith.muli %arg12, %c32_i32 : i32 + %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> + %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> + %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = 
%111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { + %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> + %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> + %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> + %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, 
#shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> + %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> + %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %244 = arith.addi %arg18, %c32_i32 : i32 + %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> + } + %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> + %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> + %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> + %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> + %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> + %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %144 = tt.addptr %143, %41 : 
tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> + %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %174 = arith.muli %arg12, %c16_i32 : i32 + %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> + %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { + %206 = 
arith.addi %arg20, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> + %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> + %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> + %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> + %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> + %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %235 = 
triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %238 = arith.addi %arg17, %c16_i32 : i32 + %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 + } + triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %178 = arith.divsi %19, %c32_i32 : i32 + %179 = arith.muli %178, %c32_i32 : i32 + %180 = arith.subi %19, %179 : i32 + %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> + %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %196 = arith.muli %arg12, %c32_i32 : i32 + %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> + %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { + %206 = arith.addi %arg19, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> + %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> + %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, 
#shared1, #triton_gpu.shared_memory, mutable> + %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> + %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> + %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 + } + triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %204 = arith.mulf 
%199#0, %cst_2 : tensor<128x64xf32, #mma> + %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> + tt.return + } +} + // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir deleted file mode 100644 index 5c173ffb4858..000000000000 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ /dev/null @@ -1,211 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s - -// Check the logic of sched-2nd-load optimizations -// - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> - -// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough -// The following tile sizes should apply the optimization -// 256x256x128 -// 256x256x64 -// The following tile sizes should NOT apply the optimization -// 256x64x128 -// 256x256x32 -// - -// Should apply: tile size 256x256x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - 
tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should apply: tile size 256x256x64 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x64 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x64x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x64x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> - 
triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x64xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x256x32 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x32 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - - -// Category 2: single dot with two loads and tile size is large enough (128x128x128). -// We make sure the move is legal. 
-// Should NOT apply: the 2nd load has a user before the dot -// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.store -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> - %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> - tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> - %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<128x128xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> - tt.return - } -} - - -// ----- - -// Category 3: two dots in the for loop. 
Make sure the optimization is not applied -// should NOT apply: two dots -// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot -// CHECK: triton_gpu.local_load -// CHECK-NEXT: triton_gpu.local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.load -// CHECK-NEXT: tt.load -// CHECK-NEXT: triton_gpu.local_store -// CHECK-NEXT: triton_gpu.local_store -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d3ffaed2e8fc..b7ee4efc72d0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -115,6 +115,56 @@ struct LocalLoadOpConversion } }; +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = op.getSrc(); + Value dst = op.getResult(); 
+ auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (isa(srcLayout) && + isa(dstLayout)) { + return lowerMfmaToDotOperand(op, adaptor, rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + if (isMfmaToDotShortcut(srcTy, dstTy)) { + // vecSize is the number of sequential elements stored by one thread + // - For MFMA encoding (encoding of the result tensor of dot + // operation) it is 4 + // - For MFMA operand encoding it is + // dotOperandEncoding::kWidth, + // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack + // = 1) + // + // For cases where these two values are equal, MFMA and MFMA operand + // layouts are the same. + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value view = + packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } + return failure(); + } +}; } // namespace namespace mlir::triton::AMD { @@ -122,6 +172,7 @@ void populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index bce126ea4d72..cece47227ea0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -38,10 +38,8 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - auto isShortcut = - mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); - - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, + isMfmaToDotShortcut); /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 343dc7b3f37f..5265f631ad9e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -694,32 +694,6 @@ struct AtomicCASOpConversion } }; -bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { - return isaFamily == triton::AMD::ISAFamily::CDNA1 || - isaFamily == triton::AMD::ISAFamily::CDNA2 || - isaFamily == triton::AMD::ISAFamily::CDNA3; -} - -Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { - assert(val.getType().isInteger(32)); - auto loc = val.getLoc(); - Value old = i32_val(0); - int rowMask = 0b1111; // enable all rows - int bankMask = 0b1111; // enable all banks - bool boundCtrl = false; - auto dppMovOp = rewriter.create( - loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); - return dppMovOp.getResult(); -} - -Value shiftLeftI32ByDpp(PatternRewriter &rewriter, 
Value val) { - return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane -} - -Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { - return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane -} - struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -791,36 +765,14 @@ struct AtomicRMWOpConversion // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; - Type packF16Ty = vec_ty(valueElemTy, 2); - - // In the case of unpaired f16 elements utilize dpp instructions to - // accelerate atomics. Here is an algorithm of lowering - // tt::atomicRmwOp(%ptr, %val, %mask): - // 0. Group thread by pairs. Master thread is (tid % 2 == 0); - // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so - // all the masters recieve value from secondary threads; - // 2. Take into account parity in the %mask value, build control flow - // structures according to it; - // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value; - // 4. All the threads send result of generated operation to (tid + 1) thread - // via dppUpdateOp shl, so all secondary thread also recieve their - // result. - // - // This approach enables us to use half the active threads committing atomic - // requests to avoid generating of code providing unified access to f16 - // element and reduce contantion. - bool useDppForPackedF16 = false; // tensor if (tensorTy) { auto valTy = cast(val.getType()); - bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); - unsigned availableVecSize = isF16Ty ? 2 : 1; - vec = std::min(vec, availableVecSize); - // Force F16 packing in the case it's not comming in as packed, but the - // ISA can support packed atomic instructions. - useDppForPackedF16 = - supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && - vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD; + Type elTy = valTy.getElementType(); + vec = std::min(vec, llvm::isa(elTy) && + elTy.getIntOrFloatBitWidth() == 16 + ? 2 + : 1); // mask numElems = tensorTy.getNumElements(); } @@ -828,15 +780,12 @@ struct AtomicRMWOpConversion auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); - if (useDppForPackedF16) - mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; - retType = useDppForPackedF16 ? packF16Ty : retType; SmallVector resultVals(elemsPerThread); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; @@ -845,24 +794,7 @@ struct AtomicRMWOpConversion Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; Value operand; - if (useDppForPackedF16) { - // Move %val to left neighbour to proceed packed atomic further. 
- Value packedVal = null(packF16Ty); - packedVal = - insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); - // Pack to i32 type to simplify transaction - packedVal = bitcast(packedVal, i32_ty); - Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); - // Unpack results back - Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); - operand = undef(packF16Ty); - operand = - insert_element(packF16Ty, operand, valElements[i], i32_val(0)); - operand = insert_element( - packF16Ty, operand, - extract_element(valueElemTy, unpackedDppRes, i32_val(0)), - i32_val(1)); - } else if (vec == 1) { + if (vec == 1) { operand = valElements[i]; } else { operand = undef(vecTy); @@ -904,25 +836,10 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { - if (useDppForPackedF16) { - // Return packed to i32 result after atomic operation back from master - // lane. - auto packedRet = bitcast(retVal, i32_ty); - Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); - // Unpack results back - Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); - retVal = insert_element( - packF16Ty, retVal, - extract_element(valueElemTy, unpackedDppRes, i32_val(1)), - i32_val(1)); - resultVals[i] = - extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); - } else { - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); - } + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); } } else { if (!atomicNeedsSharedMemory(op.getResult())) { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 9371c8b5f897..22349c50e308 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -20,15 +20,13 @@ namespace ttg = mlir::triton::gpu; // Return true if the given moduleOp contains a pure matmul problem; i.e., // single dot in the main loop. static bool isPureMatmulProblem(ModuleOp moduleOp) { - bool isMatmul = true; - bool foundLoop = false; - moduleOp.walk([&](scf::ForOp forOp) -> void { + for (auto forOp : moduleOp.getOps()) { int counter = 0; forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); - isMatmul = (isMatmul && (counter == 1)); - foundLoop = true; - }); - return foundLoop && isMatmul; + if (counter != 1) + return false; + } + return true; } // Search through block to find earliest insertion point for move op. This can @@ -269,82 +267,6 @@ static void scheduleGlobalLoadLocalStore(ModuleOp m) { } } -/** - * Sched-load optimization for matmul kernels with large tile sizes - * The basic idea of sched-load optimization is to sink the 2nd tt.load - * after local_load so that global_load instructions can be interleaved with - * mfma's. This can help hide the issue latency of global_load instructions - * and improve performance on MI300X. - * - * It's assumed that the IR before this optimization has the following - * structure: - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * tileB = tt.load b_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * After this optimization, the IR is transformed to - * ```mlir - * scf.for .. 
- * { - * tileA = tt.load a_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * tileB = tt.load b_ptr <-- 2nd tt.load is sinked here - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * For now, we don't have a perfect hueristic about when should this - * optimization be applied. Therefore, we implement a simple hueristic that - * this is applied when the tile size of A and B are large enough, i.e. - * nonKDim >= 128 and kDim >= 64. And also this is only applied for typical - * matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We - * are experimenting how to better control instruction scheduling and enable - * such optimizations. - */ -static void sinkSecondLoad(ModuleOp m) { - m.walk([&](scf::ForOp forOp) -> void { - SetVector loadOps; - triton::DotOp dotOp; - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - loadOps.insert(loadOp); - if (auto curOp = dyn_cast(&op)) - dotOp = curOp; - } - // Only apply the optimization when there are 2 load's in the loop - if (loadOps.size() != 2) - return; - // Only apply the optimization when tile size is large enough - // 1. nonKDim >= 128 - // 2. kDim >= 64 - auto ldAOp = loadOps[0]; - auto tileAShape = cast(ldAOp.getType()).getShape(); - auto ldBOp = loadOps[1]; - auto tileBShape = cast(ldBOp.getType()).getShape(); - if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) - return; - // Only apply the optimization when the moving is legal - // 1. Make sure the 2nd loadOp is before the dot - // 2. Make sure the first user of the 2nd loadOp is after the dot. - bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); - auto firstUser = *ldBOp.getResult().getUsers().begin(); - bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); - if (isBeforeDotOp && firstUserAfterDotOp) - // move ldBOp right before tt.dot - ldBOp->moveBefore(dotOp); - }); -} - //===----------------------------------------------------------------------===// // Pass definition //===----------------------------------------------------------------------===// @@ -366,10 +288,8 @@ struct TritonAMDGPUReorderInstructionsPass moveUpTranspose(m); - if (isPureMatmulProblem(m)) { + if (isPureMatmulProblem(m)) scheduleGlobalLoadLocalStore(m); - sinkSecondLoad(m); - } } }; } // namespace diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 96289bbb2e47..71fd3c0cd4e7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -631,6 +631,46 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } + + if (isMmaToDotShortcut(srcTy, dstTy)) { + // get source values + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + unsigned elems = getTotalElemsPerThread(srcTy); + Type elemTy = + this->getTypeConverter()->convertType(srcTy.getElementType()); + // for the destination type, we need to pack values together + // so they can be consumed by tensor core operations + SmallVector vecVals; + // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer + // instructions to pack & unpack sub-word integers. 
A workaround is to + // store the results of ldmatrix in i32 + auto elemSize = elemTy.getIntOrFloatBitWidth(); + if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { + auto fold = 32 / elemSize; + for (unsigned i = 0; i < elems; i += fold) { + Value val = i32_val(0); + for (unsigned j = 0; j < fold; j++) { + auto ext = + shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); + val = or_(i32_ty, val, ext); + } + vecVals.push_back(bitcast(val, i32_ty)); + } + } else { + unsigned vecSize = std::max(32 / elemSize, 1); + Type vecTy = vec_ty(elemTy, vecSize); + for (unsigned i = 0; i < elems; i += vecSize) { + Value packed = rewriter.create(loc, vecTy); + for (unsigned j = 0; j < vecSize; j++) + packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); + vecVals.push_back(bitcast(packed, i32_ty)); + } + } + Value view = + packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } return failure(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 36b14e270b27..cf0ddc248dd1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -70,18 +70,10 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { - // FIXME [Dot LL] - // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after - // we have enabled the new layout conversion for all the cases. - auto nvidiaShortCutFn = [&](RankedTensorType srcTy, - RankedTensorType dstTy) { - return matchMmaV3AndDotOperandLayout(srcTy, dstTy) || - cvtReordersRegisters(srcTy, dstTy); - }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - nvidiaShortCutFn); + isMmaToDotShortcut); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index d4c15bbad03f..fd65233e5c6b 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,9 +41,9 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, - ArrayRef warps) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, + ArrayRef order) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } @@ -301,19 +301,6 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } -TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { - EXPECT_EQ(toLinearLayout({16, 16}, - mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), - LinearLayout( - { - {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, - {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, - {S("warp"), {}}, - {S("block"), {}}, - }, - {S("dim0"), S("dim1")})); -} - TEST_F(LinearLayoutConversionsTest, 
MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -515,7 +502,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -524,7 +511,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -537,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})), + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -547,7 +534,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -567,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 897172fd6d34..f006447002ef 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -747,39 +747,6 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } -TEST_F(LinearLayoutTest, Resize) { - auto init = LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")}); - EXPECT_EQ(init.resize(S("in0"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in1"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); -} - } // anonymous namespace } // namespace mlir::triton
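
Note on the re-added MMA-to-dot shortcut lowering above: when the element type is an integer of at most 16 bits, the pattern folds 32/elemSize elements into each i32 register with a zext/shl/or_ sequence before packLLElements, which is the sub-word packing workaround its comment describes. A minimal standalone sketch of just that bit-packing step (hypothetical helper name and std containers, not part of the patch) could look like:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Fold sub-word integer elements (elemSize <= 16 bits) into 32-bit lanes,
// mirroring the zext/shl/or_ loop in the shortcut lowering: element j of a
// group lands at bit offset elemSize * j of its i32 lane.
std::vector<uint32_t> packSubWordToI32(const std::vector<uint16_t> &vals,
                                       unsigned elemSize) {
  assert(elemSize > 0 && elemSize <= 16 && 32 % elemSize == 0);
  const unsigned fold = 32 / elemSize; // elements per 32-bit lane
  assert(vals.size() % fold == 0);
  std::vector<uint32_t> packed;
  packed.reserve(vals.size() / fold);
  for (std::size_t i = 0; i < vals.size(); i += fold) {
    uint32_t lane = 0;
    for (unsigned j = 0; j < fold; ++j) {
      const uint32_t mask = (1u << elemSize) - 1u; // keep only elemSize bits
      lane |= (uint32_t(vals[i + j]) & mask) << (elemSize * j);
    }
    packed.push_back(lane);
  }
  return packed;
}
```

For 16-bit elements this packs pairs of values per register, keeping the values in i32 form and avoiding the extra pack/unpack instructions the comment attributes to the NVPTX backend.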
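
The pipelined loops in the MLIR test added above also carry a shared-memory stage index as an iter_arg (seeded with %c-1_i32) and advance it with an arith.addi / arith.cmpi slt / arith.select sequence before each triton_gpu.memdesc_subview. A sketch of just that wrap-around update, with a hypothetical numStages parameter (the test allocates a single stage, so the index always selects buffer 0):

```cpp
// Wrap-around stage index for a multi-buffered shared-memory allocation:
//   %next = arith.addi %prev, %c1
//   %ok   = arith.cmpi slt, %next, %numStages
//   %idx  = arith.select %ok, %next, %c0
// Starting from -1, the first iteration yields 0 and later iterations
// cycle through 0 .. numStages-1.
int nextStageIndex(int prev, int numStages) {
  int next = prev + 1;
  return next < numStages ? next : 0;
}
```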