diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 0952ab984cc9..b000a3129912 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -1f20eee6dc367bd202895e3eedb03974a628ef16 +b5cc222d7429fe6f18c787f633d5262fac2e676f diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index df6029db0de2..cb3e3d292efa 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -212,7 +212,11 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); + +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 8c7ab9831667..22c8f9c8a330 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -18,6 +18,14 @@ namespace gpu { SmallVector reorderValues(const SmallVector &values, Type inType, Type ouType); +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + Type getElementType(Value value); class MultipleOperandsRange @@ -179,8 +187,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32s(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); + subOperands = unpackI32(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -207,7 +215,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } resultVals = maybeDeduplicate(op, resultVals); resultVals = - packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 56a82d7cc0fb..29b8865c03ae 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1388,67 +1388,6 @@ inline Value getStructFromSharedMemoryObject(Location loc, return llvmStruct; } -// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer -// instructions to pack & unpack sub-word integers. A workaround is to -// store the results of tensors with dot operand encodings in i32 to -// facilitate instructions such as `ldmatrix`. -// -// TODO: Confirm if the problem is still there. 
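// Editor's sketch (not part of the patch): the packI32/unpackI32 helpers declared
// above and the packI32s/unpackI32s utilities removed below implement the same
// idea of storing sub-word dot-operand elements inside 32-bit words. A minimal
// standalone C++ model of that bit layout, assuming two 16-bit lanes per i32 and
// purely illustrative names; it mirrors the or_/shl/zext concat lambdas added
// later in this patch rather than any Triton API.
#include <cassert>
#include <cstdint>

static uint32_t pack2x16(uint16_t lo, uint16_t hi) {
  // lo occupies bits [15:0], hi occupies bits [31:16].
  return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}

static void unpack2x16(uint32_t word, uint16_t &lo, uint16_t &hi) {
  lo = static_cast<uint16_t>(word & 0xFFFFu);
  hi = static_cast<uint16_t>(word >> 16);
}

int main() {
  uint16_t a = 0x3C00, b = 0xBC00; // raw f16 bit patterns for 1.0 and -1.0
  uint32_t w = pack2x16(a, b);
  uint16_t x = 0, y = 0;
  unpack2x16(w, x, y);
  assert(x == a && y == b);
  return 0;
}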
-inline bool requiresI32Conversion(Type type) { - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return false; - auto dotOpEnc = dyn_cast(tensorTy.getEncoding()); - if (!dotOpEnc) - return false; - auto parent = dyn_cast(dotOpEnc.getParent()); - if (!(parent && parent.getVersionMajor() < 3)) - return false; - return true; -} - -inline SmallVector packI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - int vecWidth = 32 / eltTy.getIntOrFloatBitWidth(); - auto vecTy = vec_ty(eltTy, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecTy); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} - -inline SmallVector unpackI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - for (auto v : inValues) { - auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecTy); - for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - inline SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index c39c408d9330..283dd9165918 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -727,10 +727,6 @@ def TT_ReduceOp: TT_Op<"reduce", llvm::SmallVector getInputTypes(); llvm::SmallVector getElementTypes(); unsigned getNumOperands(); - - // Returns the CombineOp iff this ReduceOp's region contains only - // one CombineOp other than the return, or nullptr if not applicable. - ::mlir::Operation *getSingleCombiner(); }]; } diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 47e3fca79bb1..c728cfbb32cf 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -679,13 +679,6 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; - // Increase an input dimension without affecting the output dimension. The - // added free variables are mapped to 0, ensuring that the new input - // dimensions correspond directly to the existing output space. The function - // errors out if `newInDimSize` is less than the current size or the new size - // is not a power of 2. 
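// Editor's sketch (not part of the patch): the LinearLayout::resize() API removed
// by this patch is documented just above as growing an input dimension by adding
// free variables that map to 0. A toy standalone model of that operation, with a
// made-up basis representation (one basis vector per input bit); the real class
// stores considerably more state than this.
#include <cassert>
#include <cstdint>
#include <vector>

using Bases = std::vector<std::vector<int32_t>>;

static Bases resizeInDim(Bases bases, int32_t newInDimSize, size_t numOutDims) {
  int32_t oldLog2 = static_cast<int32_t>(bases.size());
  int32_t newLog2 = 0;
  while ((1 << newLog2) < newInDimSize)
    ++newLog2;
  assert((1 << newLog2) == newInDimSize && "newInDimSize must be a power of 2");
  assert(newLog2 >= oldLog2 && "newInDimSize must be >= the current size");
  // Each appended all-zero basis vector is a free variable: it enlarges the
  // input space without changing where any input maps in the output space.
  for (int32_t i = oldLog2; i < newLog2; ++i)
    bases.push_back(std::vector<int32_t>(numOutDims, 0));
  return bases;
}

int main() {
  Bases registerBases = {{1, 0}, {2, 0}}; // 4 registers, 2 output dimensions
  Bases grown = resizeInDim(registerBases, 8, /*numOutDims=*/2);
  assert(grown.size() == 3);
  assert(grown.back() == std::vector<int32_t>({0, 0}));
  return 0;
}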
- LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; - std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 665b97aeebba..276a6e7004df 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - assert(cvtNeedsSharedMemory(srcTy, dstTy)); + assert(!isMfmaToDotShortcut(srcTy, dstTy)); // FIXME This is NOT entirely correct // This should be getElemOrder, but we don't have such a method diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index aa9f8b01eae1..4915d7b1acda 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { return matrixDimsCompatible && bDimCompatible; } +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; + // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is + // improved. In addition, we can enable this shortcut for regular MFMA + // layout when opIdx == 1. + return mfmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getParent() == mfmaLayout && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); +} + // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -639,46 +655,8 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; - StringAttr kRegister = StringAttr::get(ctx, "register"); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - auto numSrcRegs = srcLayout->getInDimSize(kRegister); - auto numDstRegs = dstLayout->getInDimSize(kRegister); - // The `invertAndCompose` function will generate a layout that is injective - // by assigning new output dimensions to free variables. For instance, - // consider a scenario where `srcLayout` has a free variable in the lane - // dimension, while `dstLayout` has two free variables in the lane - // dimension and also a larger number of registers. - // The injective form of `srcLayout` will add only a single additional row - // to the transformation matrix, whereas the injective form of `dstLayout` - // will add two additional rows. 
This discrepancy causes misleading results - // because the matrices end up with a different number of rows. - // - // Take `dstLayout ⋅ srcLayout^-1` as an example: - // - // - `injective(dstLayout)`: [n, m] → [n + 2, m] - // - `injective(srcLayout)`: [n, m] → [n + 1, m] - // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] - // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + - // 1] → [n + 2, n + 1] - // - // Here, the `(n + 1)`-th row added by `dstLayout` represents the free - // variable in registers, and the `(n + 2)`-th row represents the free - // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` - // represents the free variable in lanes. As a result, the `(n + 1)`-th row - // in two layouts do not correspond to the same free variable. - // - // To address this issue, we pad the free variables in `srcLayout` and - // `dstLayout` to ensure they have the same number of registers. This - // guarantees that the resulting matrices have the same number of rows, - // ensuring consistency in the composition process. - auto numRegs = std::max(numSrcRegs, numDstRegs); - auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); - auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = - dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); + LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { @@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and - // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout - // checks. + // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, + // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully + // subsumed by the linear-layout checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. 
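// Editor's sketch (not part of the patch): the removed comment above and the
// register-index masking removed later in this diff both rest on the same fact
// about free variables: a register-index bit that is "free" does not change
// which value the register holds, so clearing free bits maps every duplicate
// register onto a canonical copy. A standalone illustration with made-up values.
#include <cassert>
#include <cstdint>

static int32_t canonicalRegister(int32_t idx, int32_t freeVarMask) {
  return idx & ~freeVarMask;
}

int main() {
  // With mask 0b00100, register 7 (0b111) holds the same value as register 3
  // (0b011) -- the example given in the removed transferWithinThread comment
  // further down in this diff.
  assert(canonicalRegister(0b111, 0b100) == 0b011);
  assert(canonicalRegister(0b011, 0b100) == 0b011);
  return 0;
}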
return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy); + !isMmaToDotShortcut(srcTy, dstTy) && + !isMfmaToDotShortcut(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { @@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) { return true; } +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; + // dot_op = #mma + // when #mma = MmaEncoding + auto mmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +} + namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 65ee8cc0023e..a18b2cbc308c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -288,90 +288,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return rewriter.notifyMatchFailure( op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - - StringAttr kBlock = str_attr("block"); - StringAttr kWarp = str_attr("warp"); - StringAttr kLane = str_attr("lane"); - StringAttr kRegister = str_attr("register"); assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, kBlock)) { + if (llvm::is_contained(dims, str_attr("block"))) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, kWarp)) { + } else if (llvm::is_contained(dims, str_attr("warp"))) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kLane)) { + } else if (llvm::is_contained(dims, str_attr("lane"))) { // Case 3. Transfer between values in the same warp, in which case we try // to move values using warp shuffles, though if the pattern is // complicated enough we may fall back to using shared memory // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kRegister) || - dstLayout.getInDimSize(kRegister) != - srcLayout.getInDimSize(kRegister)) { + } else if (llvm::is_contained(dims, str_attr("register"))) { // Case 4. 
Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread( - op, dstLayout.getFreeVariableMasks()[kRegister], - dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + return transferWithinThread(op, *conversion, adaptor, rewriter); } else { - // Cast 5. The two layouts are equivalent. We should probably remove - // these in RemoveLayoutConversion. - auto dstCvt = requiresI32Conversion(dstTy); - auto srcCvt = requiresI32Conversion(srcTy); - if (dstCvt || srcCvt) { - auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), - getTypeConverter()); - inVals = - packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); - auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, - rewriter, op.getType()); - rewriter.replaceOp(op, res); - } else { - rewriter.replaceOp(op, adaptor.getSrc()); - } + // The two layouts are equivalent. We should probably remove these in + // RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } LogicalResult - transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, - const LinearLayout &conversion, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); - SmallVector outVals(numRegs); - for (int i = 0; i < numRegs; i++) { - // Remove free masks from the register index - // For example, if idx = 0b00111, and masks = 0b00100, then we get - // 0b00011. It means that register 7 (0b111) has the same value as - // register 3 (0b011). - auto idx = i & (~regMasks); - auto srcIdx = conversion.hasInDim(kRegister) - ? 
conversion.apply({{kRegister, idx}}).begin()->second - : idx; + SmallVector outVals; + outVals.resize(conversion.getInDimSize(kRegister)); + for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -403,6 +375,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (auto dotOperand = dyn_cast(layout)) { if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { + if (product(getCTAsPerCGA(nvidiaMma)) > 1) { + return false; + } if (useLegacyMMAConversion) { return false; } @@ -412,7 +387,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; return largeKWidth && nvidiaMma.isAmpere(); } - return false; } if (isa(layout)) { return true; @@ -454,7 +428,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); } } - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); // Pretty sure this is the identity function ATM // It'd be better to simply call `quotient({kBlock})` and @@ -474,7 +447,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); + // FIXME [Dot LL] + // We know it's just for largeKWidth case in Ampere + // In this case, we need to pack the outputs into i32 + if (isa(dstTy.getEncoding())) { + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 470e8b32b540..8ee166866974 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -103,6 +103,51 @@ SmallVector reorderValues(const SmallVector &values, Type inType, llvm_unreachable("unimplemented code path"); } +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto 
encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; @@ -455,7 +500,7 @@ struct ElementwiseInlineAsmOpConversion auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); unpackedOperands.push_back( - unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -515,11 +560,10 @@ struct ElementwiseInlineAsmOpConversion unpackedResults[i], /*inType=*/op->getOperand(0).getType(), /*ouType=*/op->getResult(i).getType()); } - auto dstTy = op->getResult(i).getType(); - unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc, - getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], - rewriter, op->getResult(i).getType())); + auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), + rewriter, loc, getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, + op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index e2ed0228de8d..1a0c115a9ecf 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -184,7 +184,42 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter); + // FIXME [Dot LL] + // Ampere case + // In this case, we need to pack the outputs into i32 + if (auto dotOp = dyn_cast(dstTy.getEncoding())) { + if (auto parent = dyn_cast(dotOp.getParent())) { + if (parent.isAmpere()) { + if (elemLlvmTy.isInteger(8)) { + auto concat = [&](Value a1, Value a2, Value a3, Value a4) { + return or_( + or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), + or_(shl(zext(i32_ty, a3), i32_val(16)), + shl(zext(i32_ty, a4), i32_val(24)))); + }; + SmallVector outVals32(outVals.size() / 4); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], + outVals[4 * i + 2], outVals[4 * i + 3]); + } + outVals = outVals32; + } else { + assert(elemLlvmTy.isBF16() && "Unexpected element type"); + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + } + } + } + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git 
a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 06e75ee18d59..34fb8995430f 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -56,19 +56,20 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // This will create newArg, and map(origArg, newArg) addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, - Location loc) -> Value { + Location loc) -> std::optional { llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); - return {}; + return std::nullopt; }); // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) -> Value { + ValueRange inputs, + Location loc) -> std::optional { llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); - return {}; + return std::nullopt; }); // This will be called when (desiredType != newOperandType) @@ -78,7 +79,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, ValueRange inputs, Location loc) { auto cast = builder.create(loc, tensorType, inputs); - return cast.getResult(); + return std::optional(cast.getResult()); }); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index e77e2d5c8691..ffea5f3c67a6 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -503,22 +503,6 @@ llvm::SmallVector ReduceOp::getElementTypes() { return getElementTypesImpl(this->getOperands()); } -::mlir::Operation *ReduceOp::getSingleCombiner() { - if (getNumOperands() != 1 || getNumResults() != 1) - return nullptr; - Block *block = &(*getCombineOp().begin()); - Operation *yield = block->getTerminator(); - Operation *reduceOp = yield->getOperand(0).getDefiningOp(); - if (!reduceOp || reduceOp->getNumOperands() != 2 || - reduceOp->getNumResults() != 1) - return nullptr; - if (reduceOp->getOperand(0) != block->getArgument(0) || - reduceOp->getOperand(1) != block->getArgument(1)) - return nullptr; - - return reduceOp; -} - unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- ScanOp -- diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 28b871983bfe..5a2fcecfa949 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), (TT_DotOp $a, $b, $d, $inputPrecision, 
$maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 6978ccfb2553..56af4eaef8b9 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -551,8 +551,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. To achieve @@ -895,7 +895,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); auto order = dot.getCTAOrder(); - assert(order[0] == rank - 1 && order[1] == rank - 2); + assert(order[0] == 1 && order[1] == 0); ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); @@ -903,11 +903,13 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - auto parent = getParent(); - if (auto mfmaLayout = llvm::dyn_cast(parent)) { - return mfmaDotToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(parent)) { - if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) { + if (auto mfmaLayout = llvm::dyn_cast(getParent())) { + return dotOperandMfmaToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(getParent())) { + // FIXME [Dot LL] + // Do this unconditionally + auto largeKWidth = getKWidth() == 8; + if (mma.isAmpere() && largeKWidth) { return ampereDotToLinearLayout(shape, *this); } } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9f3d8fff491b..6d8279795209 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -290,7 +290,7 @@ struct MMAV3UseRegOperand dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) + if (!isMmaToDotShortcut(srcTy, newTy)) return failure(); Value newOperand = diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 4319d1f086dd..bf017f8c6463 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,21 +1016,6 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } -LinearLayout LinearLayout::resize(StringAttr inDim, - int32_t newInDimSize) const { - BasesT bases = getBases(); - assert(bases.contains(inDim) && "inDim not in layout"); - assert(llvm::isPowerOf2_32(newInDimSize) && - "newInDimSize must be a power of 2"); - assert(newInDimSize >= getInDimSize(inDim) && - "newInDimSize must be >= old size"); - auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); - for (int i = 0; i < numFreeVariables; i++) { - bases[inDim].push_back(std::vector(getNumOutDims(), 0)); - } - return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); -} - std::string LinearLayout::toString() const { // Start with a newline because we 
print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 208f6b80bfe5..a0719c974f9c 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -6,7 +6,7 @@ #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index df4f5ab01feb..2054853b30c1 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,7 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index a2f713faaf18..ef6733845721 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -62,151 +62,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } - -// ----- - -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16x2 - tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> - // CHECK: llvm.cond_br - // CHECK-NOT: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> - // CHECK-NOT: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> - tt.return - } -} - -// ----- - -#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16x2 - tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> - %base_ptr = tt.splat %arg0 : 
!tt.ptr -> tensor<256x!tt.ptr, #blocked2> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> - // CHECK: llvm.cond_br - // CHECK-NOT: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> - // CHECK-NOT: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> - tt.return - } -} - -// ----- - -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16_dpp - tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> - // CHECK: llvm.cond_br - // CHECK: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> - // CHECK: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> - tt.return - } -} - -// ----- - -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16_dpp - tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> - // CHECK: llvm.cond_br - // CHECK: rocdl.update.dpp - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> - // CHECK: rocdl.update.dpp - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> - tt.return - } -} - -// ----- - -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: reduce_dpp_max - tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 280, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 276, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 274, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 273, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 322, 10, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 323, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK: llvm.amdgcn.readlane 
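// Editor's sketch (not part of the patch, and not ROCDL semantics): the
// reduce_dpp_max test whose CHECK lines appear just above expects six
// update.dpp/maxnum steps followed by a readlane, i.e. log2(64) combine steps
// for a 64-lane wavefront. A host-side model of that tree reduction, using a
// plain xor butterfly in place of the specific DPP row operations.
#include <cassert>
#include <cmath>
#include <vector>

static float waveReduceMax(std::vector<float> lanes) {
  const int width = static_cast<int>(lanes.size()); // e.g. 64
  for (int offset = width / 2; offset > 0; offset /= 2) {
    std::vector<float> partner(lanes); // value each lane reads from its peer
    for (int lane = 0; lane < width; ++lane)
      lanes[lane] = std::fmax(lanes[lane], partner[lane ^ offset]);
  }
  return lanes[0]; // models the final readlane of the reduced value
}

int main() {
  std::vector<float> lanes(64);
  for (int i = 0; i < 64; ++i)
    lanes[i] = static_cast<float>((i * 37) % 64); // the values 0..63, shuffled
  assert(waveReduceMax(lanes) == 63.0f);
  return 0;
}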
- %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<64xf32, #blocked3>) -> f32 - tt.return - } -} - -// ----- - -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: reduce_xor_max - tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { - // CHECK: rocdl.ds_swizzle - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 280, 15, 12, false : i32 - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 264, 15, 3, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 276, 15, 10, false : i32 - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 260, 15, 5, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 78, 15, 15, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 177, 15, 15, false : i32 - %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32xf32, #blocked4>) -> f32 - tt.return - } -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 325c425a2277..34573f7739b8 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -821,110 +821,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg - tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> 
tensor<1xf16, #blocked> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_0 - tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_1 - tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_2 - tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_3 - tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> - tt.return - } -} - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -949,80 +845,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_layout_mmav2_dot_reg - tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, 
#mma>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_0 - tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_1 - tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_2 - tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> - tt.return - } -} - -// ----- - -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: convert_layout_mmav3_mmav3_3 - tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { - // CHECK-NOT: st.shared - // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> - tt.return - } -} - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 25897f2a9378..4fb418e3811b 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -42,7 +42,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = arith.muli %0, %c1024_i32 : i32 %sub = arith.subi %1, %c128_i32 : i32 %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 - llvm.intr.assume %cmp : i1 + 
"llvm.intr.assume"(%cmp) : (i1) -> () %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> // CHECK: %[[offset:.*]] = arith.addi diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 5dfd0f2a5f4c..686e5a24e8dd 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -460,6 +460,429 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } } +// ----- +// This test ensures that loads will not be moved across `for` loops. + +// CHECK-LABEL: tt.func public @_attn_bwd +// CHECK: tt.load +// CHECK: tt.load +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// Moved before the independent `tt.store` ops but not before the `for` ops. +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.store +// CHECK: tt.store +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// CHECK: tt.store + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant 
dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> + %0 = tt.get_program_id z : i32 + %1 = arith.muli %0, %arg14 : i32 + %2 = arith.extsi %1 : i32 to i64 + %3 = arith.remsi %0, %arg13 : i32 + %4 = arith.muli %arg11, %3 : i32 + %5 = arith.divsi %0, %arg13 : i32 + %6 = arith.muli %arg10, %5 : i32 + %7 = arith.addi %4, %6 : i32 + %8 = arith.extsi %7 : i32 to i64 + %9 = tt.get_program_id x : i32 + %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 + %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 + %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 + %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 + %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 + %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 + %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 + %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 + %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 + %19 = arith.muli %9, %c128_i32 : i32 + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> + %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> + %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> + %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> + %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> + %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %46 = tt.make_range 
{end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> + %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> + %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> + %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> + %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> + %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> + %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> + %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> + %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> + %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> + %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> + %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> + %82 = 
tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> + %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> + %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> + %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> + %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> + %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %90 = arith.muli %arg12, %c16_i32 : i32 + %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> + %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> + %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { + %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> + %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> + %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> + %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, 
#triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> + %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> + %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %250 = arith.addi %arg18, %c16_i32 : i32 + %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, 
tensor<16x64xi32, #blocked1> + scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> + } + %94 = arith.addi %19, %c128_i32 : i32 + %95 = arith.subi %arg14, %94 : i32 + %96 = arith.divsi %95, %c32_i32 : i32 + %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> + %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> + %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> + %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> + %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> + %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> + %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %122 = arith.muli %arg12, %c32_i32 : i32 + %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> + %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> + %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { + %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> + %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %210 = tt.load %209 : tensor<32x!tt.ptr, 
#triton_gpu.slice<{dim = 0, parent = #mma}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> + %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> + %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %236 = arith.subf %233, %235 : 
tensor<128x32xf32, #mma> + %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> + %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %244 = arith.addi %arg18, %c32_i32 : i32 + %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> + } + %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> + %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> + %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> + %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> + %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> + %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> + %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent 
= #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %174 = arith.muli %arg12, %c16_i32 : i32 + %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> + %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { + %206 = arith.addi %arg20, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> + %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + 
%212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> + %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> + %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> + %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> + %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %238 = arith.addi %arg17, %c16_i32 : i32 + %239 = tt.addptr %arg18, %175 : 
tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 + } + triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %178 = arith.divsi %19, %c32_i32 : i32 + %179 = arith.muli %178, %c32_i32 : i32 + %180 = arith.subi %19, %179 : i32 + %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> + %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %196 = arith.muli %arg12, %c32_i32 : i32 + %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> + %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { + %206 = arith.addi %arg19, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> + %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> + %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, 
#blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> + %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> + %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 + } + triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> + %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> + tt.return + } +} + // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir deleted file mode 100644 index 5c173ffb4858..000000000000 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ /dev/null @@ -1,211 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s - -// Check the logic of 
sched-2nd-load optimizations -// - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> - -// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough -// The following tile sizes should apply the optimization -// 256x256x128 -// 256x256x64 -// The following tile sizes should NOT apply the optimization -// 256x64x128 -// 256x256x32 -// - -// Should apply: tile size 256x256x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should apply: tile size 256x256x64 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x64 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, 
#blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x64x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x64x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x64xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x256x32 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x32 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// 
CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - - -// Category 2: single dot with two loads and tile size is large enough (128x128x128). -// We make sure the move is legal. 
-// Should NOT apply: the 2nd load has a user before the dot -// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.store -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> - %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> - tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> - %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<128x128xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> - tt.return - } -} - - -// ----- - -// Category 3: two dots in the for loop. 
Make sure the optimization is not applied -// should NOT apply: two dots -// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot -// CHECK: triton_gpu.local_load -// CHECK-NEXT: triton_gpu.local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.load -// CHECK-NEXT: tt.load -// CHECK-NEXT: triton_gpu.local_store -// CHECK-NEXT: triton_gpu.local_store -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} diff --git a/test/lib/Instrumentation/GPUHello.cpp b/test/lib/Instrumentation/GPUHello.cpp index 5c71857c8f36..3bee8ce90ced 100644 --- a/test/lib/Instrumentation/GPUHello.cpp +++ b/test/lib/Instrumentation/GPUHello.cpp @@ -61,7 +61,7 @@ bool GpuHello::runOnModule(Module &module) { PassPluginLibraryInfo getPassPluginInfo() { const auto callback = [](PassBuilder &pb) { - pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) { + pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto) { mpm.addPass(GpuHello()); return true; }); diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 0656c9d99419..51aa389a0681 100644 --- 
a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -130,7 +130,6 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX1151 = 0x04a, EF_AMDGPU_MACH_AMDGCN_GFX941 = 0x04b, EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, - EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, // First/last AMDGCN-based processors. EF_AMDGPU_MACH_AMDGCN_FIRST = EF_AMDGPU_MACH_AMDGCN_GFX600, diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 6dbb0435e20c..a7395f86dc50 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,7 +30,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" - // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index 9e174d545dd9..a49e442d3984 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -19,17 +19,6 @@ enum class ISAFamily { // Deduces the corresponding ISA family for the given target gfx |arch|. ISAFamily deduceISAFamily(llvm::StringRef arch); -// Here is a partial definition of DppCtrl enums. For the complete definition, -// please check: -// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 -enum class DppCtrl : uint32_t { - QUAD_PERM_FIRST = 0, - ROW_SHL0 = 0x100, - ROW_SHR0 = 0x110, - BCAST15 = 0x142, - BCAST31 = 0x143 -}; - } // namespace mlir::triton::AMD #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d3ffaed2e8fc..b7ee4efc72d0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -115,6 +115,56 @@ struct LocalLoadOpConversion } }; +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = op.getSrc(); + Value dst = op.getResult(); + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (isa(srcLayout) && + isa(dstLayout)) { + return lowerMfmaToDotOperand(op, adaptor, rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + if (isMfmaToDotShortcut(srcTy, dstTy)) { + // vecSize is the number of sequential elements stored by one thread + // - For MFMA encoding (encoding of the result tensor of dot + // operation) it is 4 + // - For MFMA operand encoding it is + // dotOperandEncoding::kWidth, + // which is 4 in certain cases (e.g.
fp16 and bfloat16 dtypes with kpack + // = 1) + // + // For cases where these two values are equal MFMA and MFMA operand + // layouts are the same. + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value view = + packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } + return failure(); + } +}; } // namespace namespace mlir::triton::AMD { @@ -122,6 +172,7 @@ void populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index bce126ea4d72..cece47227ea0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -38,10 +38,8 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - auto isShortcut = - mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); - - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, + isMfmaToDotShortcut); /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 343dc7b3f37f..a45efd4a7971 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -694,32 +694,6 @@ struct AtomicCASOpConversion } }; -bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { - return isaFamily == triton::AMD::ISAFamily::CDNA1 || - isaFamily == triton::AMD::ISAFamily::CDNA2 || - isaFamily == triton::AMD::ISAFamily::CDNA3; -} - -Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { - assert(val.getType().isInteger(32)); - auto loc = val.getLoc(); - Value old = i32_val(0); - int rowMask = 0b1111; // enable all rows - int bankMask = 0b1111; // enable all banks - bool boundCtrl = false; - auto dppMovOp = rewriter.create( - loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); - return dppMovOp.getResult(); -} - -Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) { - return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane -} - -Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { - return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane -} - struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -791,36 +765,10 @@ struct AtomicRMWOpConversion // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; - Type packF16Ty = vec_ty(valueElemTy, 2); - - // In the case of unpaired f16 elements utilize dpp instructions to - // accelerate atomics. Here is an algorithm of lowering - // tt::atomicRmwOp(%ptr, %val, %mask): - // 0. Group thread by pairs. Master thread is (tid % 2 == 0); - // 1. 
All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so - // all the masters recieve value from secondary threads; - // 2. Take into account parity in the %mask value, build control flow - // structures according to it; - // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value; - // 4. All the threads send result of generated operation to (tid + 1) thread - // via dppUpdateOp shl, so all secondary thread also recieve their - // result. - // - // This approach enables us to use half the active threads committing atomic - // requests to avoid generating of code providing unified access to f16 - // element and reduce contantion. - bool useDppForPackedF16 = false; // tensor if (tensorTy) { auto valTy = cast(val.getType()); - bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); - unsigned availableVecSize = isF16Ty ? 2 : 1; - vec = std::min(vec, availableVecSize); - // Force F16 packing in the case it's not comming in as packed, but the - // ISA can support packed atomic instructions. - useDppForPackedF16 = - supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && - vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD; + vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); // mask numElems = tensorTy.getNumElements(); } @@ -828,49 +776,20 @@ struct AtomicRMWOpConversion auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); - if (useDppForPackedF16) - mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; - retType = useDppForPackedF16 ? packF16Ty : retType; SmallVector resultVals(elemsPerThread); + const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; - Value operand; - if (useDppForPackedF16) { - // Move %val to left neighbour to proceed packed atomic further. - Value packedVal = null(packF16Ty); - packedVal = - insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); - // Pack to i32 type to simplify transaction - packedVal = bitcast(packedVal, i32_ty); - Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); - // Unpack results back - Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); - operand = undef(packF16Ty); - operand = - insert_element(packF16Ty, operand, valElements[i], i32_val(0)); - operand = insert_element( - packF16Ty, operand, - extract_element(valueElemTy, unpackedDppRes, i32_val(0)), - i32_val(1)); - } else if (vec == 1) { - operand = valElements[i]; - } else { - operand = undef(vecTy); - for (size_t ii = 0; ii < vec; ++ii) - operand = - insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); - } - Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); @@ -887,11 +806,25 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. 
- Value atom = - rewriter - .create(loc, *maybeKind, rmwPtr, operand, - atomicMemOrdering, StringRef("agent")) - .getResult(); + Value atom = rewriter + .create( + loc, *maybeKind, rmwPtr, valElements[i], + atomicMemOrdering, StringRef("agent")) + .getResult(); + + // NV for the f16v2 case generates one packed instruction. We have to + // create two separate instructions since LLVM::AtomicRMWOp doesn't + // support this. Can be optimized out with rocdl.raw.buffer.atomic. + if (f16v2) { + Value atom2 = + rewriter + .create( + loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], + atomicMemOrdering, StringRef("agent")) + .getResult(); + auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); + atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); + } if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = @@ -904,25 +837,10 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { - if (useDppForPackedF16) { - // Return packed to i32 result after atomic operation back from master - // lane. - auto packedRet = bitcast(retVal, i32_ty); - Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); - // Unpack results back - Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); - retVal = insert_element( - packF16Ty, retVal, - extract_element(valueElemTy, unpackedDppRes, i32_val(1)), - i32_val(1)); - resultVals[i] = - extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); - } else { - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); - } + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); } } else { if (!atomicNeedsSharedMemory(op.getResult())) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 525361fee603..3a40d73c2a7c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -5,7 +5,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -using mlir::triton::AMD::DppCtrl; namespace mlir::triton::AMD { namespace { @@ -104,22 +103,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleXor(loc, rewriter, val, i); } Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleUp(loc, rewriter, val, i); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } Value TargetInfo::programId(RewriterBase &rewriter, Location loc, @@ -127,184 +126,11 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } -// Cast and 
sext values into specific-length int to meet the requirements of -// instructions like UpdateDpp or readlane if necessary. -static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, - Value &val, Type fromType, - unsigned toBits) { - unsigned originalBits = fromType.getIntOrFloatBitWidth(); - Type toType = fromType; - - if (!fromType.isIntOrIndex()) { - val = bitcast(val, int_ty(originalBits)); - toType = int_ty(originalBits); - } - - if (originalBits < toBits) { - val = sext(int_ty(toBits), val); - toType = int_ty(toBits); - } - - return toType; -} - -// Trunc the value to specific length and then cast it to given type if -// necessary. This function is typically used in conjunction with -// castToAndSExtInt. -static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, - Value val, Type valType, - unsigned fromBits) { - unsigned originalBits = valType.getIntOrFloatBitWidth(); - Value toVal = val; - - if (originalBits < fromBits) { - toVal = trunc(int_ty(originalBits), toVal); - } - - if (!valType.isIntOrIndex()) { - toVal = bitcast(toVal, valType); - } - - return toVal; -} - bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - if (numLaneToReduce != 64) - return false; - - if (auto family = getISAFamily(); - family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) { - return false; - } - - Operation *reduxOp = op.getSingleCombiner(); - if (!reduxOp) - return false; - - auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, - uint32_t dppCtrl, int rowMask, - int bankMask) -> Value { - // DPP has limited support for data types, so here we need to - // cast non-integer types or integer types shorter than 32 bits - // to int32, except for fp32. - Type actualType = valType; - if (!valType.isF32()) { - actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); - } - - Value dppResult = - rewriter - .create(loc, actualType, src, src, - rewriter.getI32IntegerAttr(dppCtrl), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), - rewriter.getBoolAttr(true)) - .getRes(); - - if (!valType.isF32()) { - src = truncAndCastFromInt(rewriter, loc, src, valType, 32); - dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); - } - - IRMapping mapping; - mapping.map(reduxOp->getOperand(0), src); - mapping.map(reduxOp->getOperand(1), dppResult); - return rewriter.clone(*reduxOp, mapping)->getResult(0); - }; - - for (int i = 0; i < acc.size(); i++) { - Value buf; - auto valType = acc[i].getType(); - - /* - Here's the implementation of full-wavefront reduction using dpp. - https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ - - Each step has a v_mov_dpp instruction following the redux op. In - some cases, the lower-level compiler could merge them into single - instruction. For example, v_mov_dpp + max => v_max_dpp. - - For gfx9, we have 64 threads per warp. These 64 threads are arranged - into 4 rows, with each row being 16 threads. Each 16 threads are arranged - further into 4 banks, with each bank being 4 threads. Overall it's in a - (row, bank, thread) structure. When shuffling, we use row/bank mask to - indicate which row/bank to participate. Then modifier like row_shr and - row_bcast means exact data movement schemes. In the following - instructions, taking row 0 as an example: - - Step 1: Right shift for 8 lanes. - lane 8-15 = redux(lane 0-7, lane 8-15) - - Step 2: Right shift for 4 lanes. 
- lane 12-15 = redux(lane 8-11, lane 12-15) - - Step 3: Right shift for 2 lanes. - lane 14-15 = redux(lane 12-13, lane 14-15) - - Step 4: Right shift for 1 lane. - lane 15 = redux(lane 14, lane 15) - - Step 5: Broadcast lane 15 of each row to all the lanes of its next row. - lane 16-31 = redux(lane 15, lane 16-31) - - Step 6: Broadcast lane 31 to lane 32-63. - lane 32-63 = redux(lane 31, lane 32-63) - - Now the reduction result is stored in lane 63. - - Step 7: Read the reduction result from lane 63 and broadcast with - readlane. - */ - - const int allRows = 0xf; - const int allBanks = 0xf; - - const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); - - // row_shr:8 - buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:4 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:2 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:1 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, - allRows, allBanks); - - // row_bcast:15 row_mask:0xa - buf = createDppReduxOpWithBoundCtrl( - valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); - - // row_bcast:31 - buf = createDppReduxOpWithBoundCtrl(valType, buf, - static_cast(DppCtrl::BCAST31), - allRows, allBanks); - - // Similarly, we need to cast data types for readlane instruction. - Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); - - // Get reduction result from lane 63 - std::string intrinsic = "llvm.amdgcn.readlane"; - Value result = - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, - ValueRange{buf, i32_val(63)}) - ->getResult(0); - - result = truncAndCastFromInt(rewriter, loc, result, valType, 16); - - acc[i] = result; - } - - return true; + return false; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 7ab6fd68a5d5..63fb972f7903 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -11,7 +11,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { // CDNA ISA cases switch (kind) { - case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: case llvm::AMDGPU::GK_GFX941: case llvm::AMDGPU::GK_GFX940: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 0bd401f1993a..542b1ecbb7fb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,8 +8,6 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" -using mlir::triton::AMD::DppCtrl; -using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -73,9 +71,8 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, - ISAFamily isaFamily, Value val, Value i, - int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -87,8 +84,7 @@ static Value 
shuffleCommon(Location loc, RewriterBase &rewriter, if (bits < 32) val = sext(i32_ty, val); - val = - shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); + val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -102,10 +98,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, - clamp); - val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, - clamp); + val0 = shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); + val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -140,83 +134,13 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value stride = i32_val(32); Value lineId = xor_(threadId, stride); return bpermute(lineId); - } else if (strideInt == 16) { - Value offset = i32_val(0x401F); - return rewriter.create(loc, valType, val, offset); } else { - if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { - // DPP is only supportted for CDNA2 and CDNA3 right now, so we fallback - // to ds_swizzle for other archs. - // - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); - return rewriter.create(loc, valType, val, offset); - } - - auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, - uint32_t dppCtrl, uint32_t rowMask, - uint32_t bankMask) { - return rewriter.create( - loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); - }; - - const int allRows = 0xf; - const int allBanks = 0xf; - - switch (strideInt) { - case 1: { - // quad_perm: 1, 0, 3, 2 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {1, 0, 3, 2}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); - } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 2: { - // quad_perm: 2, 3, 0, 1 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {2, 3, 0, 1}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); - } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 4: { - // row_shr:4 bank_mask: 0xa - auto ret = createDppOpWithoutBoundCtrl( - val, val, 4 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xa) - .getRes(); - - // row_shl:4 bank_mask: 0x5 - return createDppOpWithoutBoundCtrl( - ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x5); - } - case 8: { - // row_shr:8 bank_mask: 0xc - auto ret = createDppOpWithoutBoundCtrl( - val, val, 8 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xc) - .getRes(); - - // row_shl:8 bank_mask: 0x3 - return createDppOpWithoutBoundCtrl( - ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x3); - } - default: - assert(false && - "bfly shfl with stride >= 16 should not be handled by dpp."); - } + // This map facilates the butterfly shuffle pattern for a stride less + // than 16. 
The pattern stride is the key of the map. + DenseMap masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = i32_val(masks[strideInt]); + return rewriter.create(loc, valType, val, offset); } break; case ShflKind::up: { @@ -234,27 +158,22 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, return Value(); } -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::bfly, i32_val(0x1f)); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, + i32_val(0x1f)); } -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::up, i32_val(0x0)); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, + i32_val(0x0)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, - i32_val(0x1f)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { + return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index d150531848e3..123234fd4824 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -2,14 +2,12 @@ #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" -#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" - namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; @@ -21,18 +19,10 @@ const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS"; const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value 
shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index c3a69a5f9a2a..7da8083cfb92 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -8,7 +8,6 @@ add_triton_library(TritonAMDGPUTransforms MfmaGroup.cpp DEPENDS - TritonAMDGPUIR TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 9371c8b5f897..e122f15fd901 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -5,30 +5,23 @@ #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/STLExtras.h" +#include + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" using namespace mlir; namespace ttg = mlir::triton::gpu; - -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - -// Return true if the given moduleOp contains a pure matmul problem; i.e., -// single dot in the main loop. -static bool isPureMatmulProblem(ModuleOp moduleOp) { - bool isMatmul = true; - bool foundLoop = false; - moduleOp.walk([&](scf::ForOp forOp) -> void { - int counter = 0; - forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); - isMatmul = (isMatmul && (counter == 1)); - foundLoop = true; - }); - return foundLoop && isMatmul; +namespace tt = mlir::triton; + +static bool isLocalLoadOrDotLayoutConversion(Operation *op) { + if (isa(op)) + return true; + if (auto cvt = dyn_cast(op)) + return isa(cvt.getType().getEncoding()); + return false; } // Search through block to find earliest insertion point for move op. This can @@ -68,311 +61,194 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } -// Return the first user in the same block of the given op. If the user is in a -// nested block then return the op owning the block. Return nullptr if not -// existing. -static Operation *getFirstUseInSameBlock(Operation *op) { - SmallVector usersInSameBlock; - for (auto user : op->getUsers()) { - if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) - usersInSameBlock.push_back(ancestor); - } - auto minOpIt = - llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); - return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; -} - // Check if the operation opInsideLoop is inside any scf::ForOp and // opOutsideLoop is not inside the same loop. 
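A few hunks above, the restored butterfly path maps each power-of-two shuffle stride up to 16 to a fixed ds_swizzle_b32 offset ({16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}). Assuming the usual bit-mask reading of that offset field (and_mask in bits 4:0, or_mask in bits 9:5, xor_mask in bits 14:10, an assumption worth checking against the ISA manual), every entry amounts to "XOR the lane id with the stride inside a 32-lane group", which is exactly what shuffleXor needs. A small self-contained sketch of that reading:

```cpp
// Sketch of the stride-to-offset table restored above, under the assumed
// ds_swizzle_b32 bit-mask encoding. Not the lowering code itself.
#include <cstdint>
#include <cstdio>

// Build the swizzle offset for a butterfly exchange of the given stride (< 32).
uint16_t bflySwizzleOffset(unsigned stride) {
  return static_cast<uint16_t>((stride << 10) | 0x1F);  // xor_mask = stride
}

// Source lane selected by ds_swizzle_b32 in bit-mask mode for a given offset.
unsigned swizzleSource(unsigned lane, uint16_t offset) {
  unsigned andMask = offset & 0x1F;
  unsigned orMask = (offset >> 5) & 0x1F;
  unsigned xorMask = (offset >> 10) & 0x1F;
  return ((lane & andMask) | orMask) ^ xorMask;
}

int main() {
  const unsigned strides[] = {16, 8, 4, 2, 1};
  for (unsigned stride : strides)
    printf("stride %2u -> offset 0x%04X\n", stride,
           static_cast<unsigned>(bflySwizzleOffset(stride)));
  // For stride 4, lane 6 exchanges with lane 2 (6 ^ 4), matching shuffleXor.
  printf("lane 6, stride 4 reads from lane %u\n",
         swizzleSource(6, bflySwizzleOffset(4)));
  return 0;
}
```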
-static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, - mlir::Operation *opOutsideLoop) { +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { scf::ForOp parentForOp = opInsideLoop->getParentOfType(); return parentForOp && !parentForOp->isAncestor(opOutsideLoop); } -//===----------------------------------------------------------------------===// -// Reorder mechanisms -//===----------------------------------------------------------------------===// - -// Sink dot layout conversions into loops to decrease register pressure when -// possible. -static void sinkDotConversion(ModuleOp moduleOp) { - DenseMap opToMove; - moduleOp.walk([&](ttg::ConvertLayoutOp op) { - Attribute encoding = op.getType().getEncoding(); - if (!isa_and_nonnull(encoding)) - return; - if (!op->hasOneUse()) - return; - Operation *user = *op->getUsers().begin(); - if (user->getParentOfType() == - op->getParentOfType()) - return; - opToMove[op] = user; - }); - - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); -} - -// Adjust the placement of shared memory writes and reads to immediately follow -// the definition of their operands in case where shared memory write is in the -// loop but its operand is not. -// -// This is a heuristic driven by optimizing fused attention by hoisting Q tensor -// shared memory read/write operations outside of the loop, as Q is a loop -// invariant and can be loaded once before entering the loop. But it should be -// generally applicable. -// -// There are two possible patterns for this adjustment depending on whether the -// write to shared memory is performed using an optional `local_alloc` argument -// or a `local_store` instruction. -// -// 1) %1 = some_op ... (typically a load or an operation that scales the tensor -// after loading) -// %2 = local_alloc %1 -// %3 = local_load %2 -// -// 2) %1 = some_op ... -// %2 = local_alloc -// %3 = local_store %1, %2 -// %4 = local_load %2 -static void hoistLocalLoad(ModuleOp moduleOp) { - moduleOp.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) - return; - - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) - return; - - auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) - return; - - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } - - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); - - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; - - // localStore comes before localLoad in block. - Operation *localStore = getFirstUseInSameBlock(localAlloc); - if (!isa(localStore)) - return; - - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. 
- if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; +class TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { +public: + TritonAMDGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); } - - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); -} - -// Sink conversion after the last dealloc but before the first use in its block. -// This helps to avoid unnecessary shared memory allocation. -static void moveDownCoversion(ModuleOp moduleOp) { - SmallVector convertOps; - moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); - - for (auto op : convertOps) { - Operation *user = getFirstUseInSameBlock(op); - for (auto it = Block::iterator(op), ie = op->getBlock()->end(); - it != ie && &*it != user; ++it) - if (isa(&*it)) - op->moveAfter(&*it); + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; } -} - -// Move transpositions just after their definition. -static void moveUpTranspose(ModuleOp moduleOp) { - SmallVector transOps; - moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); - - for (auto op : transOps) - if (Operation *argOp = op.getSrc().getDefiningOp()) - op->moveAfter(argOp); -} - -// Schedule global load and local store ops for better GEMM performance. -static void scheduleGlobalLoadLocalStore(ModuleOp m) { - SmallVector moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); - // Move local_stores early if dependence distance greater than one iteration. - // Best perf on GEMM when these precede global loads. - m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - for (auto op : llvm::reverse(moveOps)) { - // Gather use-def chain in block. - Block *block = op->getBlock(); - bool leadsToLoad = false; - SetVector backwardSet; + void runOnOperation() override { + ModuleOp m = getOperation(); - BackwardSliceOptions options; - options.omitBlockArguments = true; - options.inclusive = false; - options.filter = [&](Operation *defOp) -> bool { - Block *defBlock = defOp->getBlock(); - if (!block->findAncestorOpInBlock(*defOp)) - return false; - // Check for a `load` dependent path. - leadsToLoad |= isa(defOp); - // Only move ops residing in the same block. - return defBlock == block; - }; - mlir::getBackwardSlice(op, &backwardSet, options); - backwardSet.insert(op); + // Sink shared memory loads and layout conversions into loops to decrease + // register pressure when possible. 
+ DenseMap opToMove; + m.walk([&](Operation *op) { + if (!isLocalLoadOrDotLayoutConversion(op)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, user}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + opToMove.clear(); + + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in case where LDS write is in the + // loop but it's operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. + // + // clang-format off + // + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) + // %2 = local_alloc %1 + // %3 = local_load %2 + // + // 2) %1 = some_op ... + // %2 = local_alloc + // %3 = local_store %1, %2 + // %4 = local_load %2 + // + // clang-format on + m.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) + return; - // Don't move a local_store if its source is a load from - // the same iteration. - if (isa(op) && leadsToLoad) - continue; + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) + return; - auto ipoint = findEarlyInsertionPoint(block, op); - // Remove ops that already precede the insertion point. This is done - // before moves happen to avoid `Operation::isBeforeInBlock` N^2 - // complexity. + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { + return; + } - SmallVector dfg = backwardSet.takeVector(); - if (ipoint != block->end()) { - // Move ops to insertion point. - llvm::erase_if( - dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveAfter(block, ipoint); - } else { - // Move ops to block begin. - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveBefore(block, block->begin()); - } - } -} + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } -/** - * Sched-load optimization for matmul kernels with large tile sizes - * The basic idea of sched-load optimization is to sink the 2nd tt.load - * after local_load so that global_load instructions can be interleaved with - * mfma's. This can help hide the issue latency of global_load instructions - * and improve performance on MI300X. - * - * It's assumed that the IR before this optimization has the following - * structure: - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * tileB = tt.load b_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * After this optimization, the IR is transformed to - * ```mlir - * scf.for .. 
- * { - * tileA = tt.load a_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * tileB = tt.load b_ptr <-- 2nd tt.load is sinked here - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * For now, we don't have a perfect hueristic about when should this - * optimization be applied. Therefore, we implement a simple hueristic that - * this is applied when the tile size of A and B are large enough, i.e. - * nonKDim >= 128 and kDim >= 64. And also this is only applied for typical - * matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We - * are experimenting how to better control instruction scheduling and enable - * such optimizations. - */ -static void sinkSecondLoad(ModuleOp m) { - m.walk([&](scf::ForOp forOp) -> void { - SetVector loadOps; - triton::DotOp dotOp; - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - loadOps.insert(loadOp); - if (auto curOp = dyn_cast(&op)) - dotOp = curOp; - } - // Only apply the optimization when there are 2 load's in the loop - if (loadOps.size() != 2) - return; - // Only apply the optimization when tile size is large enough - // 1. nonKDim >= 128 - // 2. kDim >= 64 - auto ldAOp = loadOps[0]; - auto tileAShape = cast(ldAOp.getType()).getShape(); - auto ldBOp = loadOps[1]; - auto tileBShape = cast(ldBOp.getType()).getShape(); - if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) - return; - // Only apply the optimization when the moving is legal - // 1. Make sure the 2nd loadOp is before the dot - // 2. Make sure the first user of the 2nd loadOp is after the dot. - bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); - auto firstUser = *ldBOp.getResult().getUsers().begin(); - bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); - if (isBeforeDotOp && firstUserAfterDotOp) - // move ldBOp right before tt.dot - ldBOp->moveBefore(dotOp); - }); -} + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); -//===----------------------------------------------------------------------===// -// Pass definition -//===----------------------------------------------------------------------===// + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" + // localStore comes before localLoad in block. + Operation *localStore = getFirstUse(localAlloc); + if (!isa(localStore)) + return; -namespace { -struct TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { - void runOnOperation() override { - ModuleOp m = getOperation(); + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } - hoistLocalLoad(m); + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); - sinkDotConversion(m); - moveDownCoversion(m); + // Sink conversion after the last dealloc but before the first use ancestor + // in its block. This helps to avoid unnecessary shared memory allocation. 
+ m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); - moveUpTranspose(m); + // Move transpositions just after their definition. + m.walk([&](triton::TransOp op) { + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); + }); - if (isPureMatmulProblem(m)) { - scheduleGlobalLoadLocalStore(m); - sinkSecondLoad(m); + SmallVector moveOps; + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than + // one iteration. + // Best perf on GEMM when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } } } }; -} // namespace std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 96289bbb2e47..71fd3c0cd4e7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -631,6 +631,46 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } + + if (isMmaToDotShortcut(srcTy, dstTy)) { + // get source values + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + unsigned elems = getTotalElemsPerThread(srcTy); + Type elemTy = + this->getTypeConverter()->convertType(srcTy.getElementType()); + // for the destination type, we need to pack values together + // so they can be consumed by tensor core operations + SmallVector vecVals; + // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer + // instructions to pack & unpack sub-word integers. 
A workaround is to + // store the results of ldmatrix in i32 + auto elemSize = elemTy.getIntOrFloatBitWidth(); + if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { + auto fold = 32 / elemSize; + for (unsigned i = 0; i < elems; i += fold) { + Value val = i32_val(0); + for (unsigned j = 0; j < fold; j++) { + auto ext = + shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); + val = or_(i32_ty, val, ext); + } + vecVals.push_back(bitcast(val, i32_ty)); + } + } else { + unsigned vecSize = std::max(32 / elemSize, 1); + Type vecTy = vec_ty(elemTy, vecSize); + for (unsigned i = 0; i < elems; i += vecSize) { + Value packed = rewriter.create(loc, vecTy); + for (unsigned j = 0; j < vecSize; j++) + packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); + vecVals.push_back(bitcast(packed, i32_ty)); + } + } + Value view = + packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } return failure(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 36b14e270b27..cf0ddc248dd1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -70,18 +70,10 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { - // FIXME [Dot LL] - // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after - // we have enabled the new layout conversion for all the cases. - auto nvidiaShortCutFn = [&](RankedTensorType srcTy, - RankedTensorType dstTy) { - return matchMmaV3AndDotOperandLayout(srcTy, dstTy) || - cvtReordersRegisters(srcTy, dstTy); - }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - nvidiaShortCutFn); + isMmaToDotShortcut); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index d1cef15a354e..75f9354104b1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -93,12 +93,20 @@ static std::optional matchReduxKind(triton::ReduceOp op, int computeCapability) { if (computeCapability < 80) return std::nullopt; - Operation *reduceOp = op.getSingleCombiner(); - if (!reduceOp) + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return std::nullopt; + Block *block = &(*op.getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) return std::nullopt; auto intType = dyn_cast(reduceOp->getResultTypes()[0]); if (!intType || intType.getWidth() > 32) return std::nullopt; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return std::nullopt; if (isa(reduceOp)) return NVVM::ReduxKind::ADD; if (isa(reduceOp)) diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 
d4c15bbad03f..fd65233e5c6b 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,9 +41,9 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, - ArrayRef warps) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, + ArrayRef order) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } @@ -301,19 +301,6 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } -TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { - EXPECT_EQ(toLinearLayout({16, 16}, - mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), - LinearLayout( - { - {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, - {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, - {S("warp"), {}}, - {S("block"), {}}, - }, - {S("dim0"), S("dim1")})); -} - TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -515,7 +502,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -524,7 +511,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -537,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})), + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -547,7 +534,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -567,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 897172fd6d34..f006447002ef 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -747,39 +747,6 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } -TEST_F(LinearLayoutTest, Resize) { - auto init = LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")}); - EXPECT_EQ(init.resize(S("in0"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - 
{S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in1"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); -} - } // anonymous namespace } // namespace mlir::triton