diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index df6029db0de2..ae517912fbb4 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);
 
+// Check if the MFMA layout can be converted to the dot operand
+// layout using a warp shuffle.
+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy);
+
 // TODO: Move utility functions that belong to ConvertLayoutOp to class
 // ConvertLayoutOpHelper in the future
 bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 6166e1019901..5fd87e4c0169 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Conversion/MLIRTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -632,6 +633,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }
 
+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy) {
+  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (!mfmaLayout || !dotOperandLayout)
+    return false;
+
+  // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
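+  // The shortcut applies when each lane of a transposed MFMA result holds
+  // 4 contiguous K elements while the kWidth == 8 dot operand needs 8, and
+  // the missing half resides in another lane of the same warp
+  // (warpsPerCTA[1] == 1), so it can be fetched with a warp shuffle.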
+  return dotOperandLayout.getParent() == mfmaLayout &&
+         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
+         dotOperandLayout.getKWidth() == 8 &&
+         getContigPerThread(mfmaLayout)[1] == 4 &&
+         ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
+          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
+         triton::type::isFloat8(srcTy.getElementType()) &&
+         triton::type::isFloat8(dstTy.getElementType()) &&
+         mfmaLayout.getWarpsPerCTA()[1] == 1;
+}
+
 // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
 // under kBlock, kWarp or kLane (in that order). The idea here is that if we
 // have a transformation that's the identity on kBlock, we don't need to use
@@ -730,7 +750,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
+         // to be removed when generalized warp shuffle conversions
+         // are ready:
+         !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
 }
 
 bool atomicNeedsSharedMemory(Value value) {
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index aab97c7dd2b3..f0026c199324 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -409,6 +409,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return failure();
     }
 
+    // The following check can be removed when generalized warp shuffle
+    // conversions are ready:
+    if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
+      return failure();
+    }
+
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
     SmallVector<Value> inVals =
diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir
index a2c8f48718d9..bcbc7eff590e 100644
--- a/test/Conversion/amd/mfma-shortcut.mlir
+++ b/test/Conversion/amd/mfma-shortcut.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s
+// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s
 
 #mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
@@ -27,3 +27,191 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_f8_mfma32
+  tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+
+    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
+    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
+
+    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
+    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
+
+    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
+    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
+    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
+    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
+    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
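+    // Each lane packs its four i8 values into one 32-bit vector so a single
+    // ds_bpermute moves all of them; the instruction takes a byte address,
+    // hence the source lane id is shifted left by 2.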
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
+
+    // Input (8 values): (vec0, vec1)
+    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
+    //    resVec0     resVec1
+    // lanes 0-31:  (vec0      , vec0 >> 32) (mask0=1)
+    // lanes 32-63: (vec1 >> 32, vec1      ) (mask0=0)
+
+    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
+    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
+
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
+    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
+    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
+
+    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
+    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
+
+    // CHECK: llvm.return
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
+  tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: llvm.return
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_f8_mfma16
+  tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+
+    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
+    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
+
+    // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
+    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
+    // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
+    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
+
+    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
+    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
+    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
+
+    // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
+    // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
+    // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
+    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
+    // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
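+    // With 16x16 tiles a lane's output may come from any of the three other
+    // 16-lane groups, so shuffle addresses for lane offsets 16, 32, and 48
+    // are all precomputed.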
+    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
+    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
+
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
+    // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
+
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
+    // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
+
+    // Input (8 values): (vec0, vec1)
+    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
+    //    resVec0     resVec1
+    // lanes 0-15:  (vec0      , vec0 >> 16) (mask0=1, mask1=1)
+    // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
+    // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
+    // lanes 48-63: (vec1 >> 48, vec1      ) (mask0=0, mask1=0)
+
+    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
+    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
+    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
+
+    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
+    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
+    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
+
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
+    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
+    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
+
+    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
+    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
+
+    // CHECK: llvm.return
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
+  tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: llvm.return
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 208483beb8fc..3b61fb8cc467 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -116,6 +116,158 @@ struct LocalLoadOpConversion
   }
 };
 
+struct ConvertLayoutOpMFMAToDotOpConversion
+    : public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
+public:
+  explicit ConvertLayoutOpMFMAToDotOpConversion(
+      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+      PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(typeConverter,
+                                                             benefit),
+        targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = cast<RankedTensorType>(op.getSrc().getType());
+    auto dstType = cast<RankedTensorType>(op.getType());
+
+    if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType))
+      return failure();
+
+    auto loc = op.getLoc();
+
+    SmallVector<Value> inVals =
+        unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    if (inVals.empty() || inVals.size() % 8 != 0)
+      return failure();
+
+    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcType.getEncoding());
+    assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) &&
+           "Expected MFMA size 16 or 32");
+    assert(triton::gpu::getWarpSize(mfmaLayout) == 64 &&
+           "Expected warp size 64 for MFMA");
+
+    auto elemTy = int_ty(8);
+    auto vecTy = vec_ty(elemTy, 4);
+
+    Value c16 = i32_val(16);
+    Value c32 = i32_val(32);
+    Value c48 = i32_val(48);
+    Value c64 = i32_val(64);
+
+    Value threadId = tid_val();
+    Value laneId = urem(threadId, c64);
+
+    Value mask0 = icmp_slt(laneId, c32);
+    Value mask1 = icmp_slt(urem(laneId, c32), c16);
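+    // mask0 splits the warp into 32-lane halves; mask1 splits each half
+    // into 16-lane quarters. The selects below use them to pick which of
+    // the own/shuffled registers supplies each lane's output.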
+
+    Value addrShift16 = urem(add(laneId, c16), c64);
+    Value addrShift32 = urem(add(laneId, c32), c64);
+    Value addrShift48 = urem(add(laneId, c48), c64);
+
+    SmallVector<Value> outVals;
+    for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) {
+      Value vec0 = undef(vecTy);
+      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
+        vec0 =
+            insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx));
+      }
+      Value vec1 = undef(vecTy);
+      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
+        vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4],
+                              i32_val(vIdx));
+      }
+
+      Value resVec0, resVec1;
+      if (mfmaLayout.getMDim() == 32) {
+        /*
+        Using wave shuffle to convert layouts (32x32x16 case):
+        1) Input MMA layout (32x32, fp8, 16 values):
+         _____________________________________________________________
+        |(t0  v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)|
+        |                       ...                                   |
+        |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)|
+        |_____________________________________________________________|
+
+        2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each):
+         ____________________________________________________________  ___
+        |(t0  v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) ||
+        |                       ...                                  ||...
+        |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) ||
+        |____________________________________________________________||___
+        */
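+
+        // Lanes 0-31 keep vec0 and fetch their upper half from vec0 of the
+        // opposite 32-lane group; lanes 32-63 keep vec1 and fetch their
+        // lower half from vec1 of the opposite group (one shuffle per
+        // register, then a select on mask0).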
+        Value shflVec0 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32),
+                    vecTy);
+        Value shflVec1 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32),
+                    vecTy);
+
+        resVec0 = select(mask0, vec0, shflVec1);
+        resVec1 = select(mask0, shflVec0, vec1);
+      } else if (mfmaLayout.getMDim() == 16) {
+        /*
+        16x16x32 case:
+        1) Input MMA layout (two 16x16, fp8, 4 values each):
+         _________________________________________________________  ___________
+        |(t0  v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0  v4 ...
+        |          ...                               ...          ||   ...
+        |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ...
+        |_________________________________________________________||___________
+
+        2) Output Dot operand layout (16x32 tile, fp8, 8 values):
+         ________________________________________________________________
+        |(t0  v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) |
+        |                         ...                                    |
+        |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) |
+        |________________________________________________________________|
+        */
+
+        Value shflVec0_16 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16),
+                    vecTy);
+        Value shflVec0_32 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32),
+                    vecTy);
+        Value shflVec1_32 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32),
+                    vecTy);
+        Value shflVec1_48 =
+            bitcast(targetInfo.shuffleIdx(
+                        rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48),
+                    vecTy);
+
+        resVec0 = select(mask0, select(mask1, vec0, shflVec0_16),
+                         select(mask1, shflVec1_32, shflVec1_48));
+        resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32),
+                         select(mask1, shflVec1_48, vec1));
+      }
+
+      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
+        outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx)));
+      }
+      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
+        outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx)));
+      }
+    }
+
+    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
+                                  op.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+protected:
+  const TargetInfoBase &targetInfo;
+};
+
 } // namespace
 
 namespace mlir::triton::AMD {
@@ -124,5 +276,7 @@ void populateConvertLayoutOpToLLVMPatterns(
     RewritePatternSet &patterns, int numWarps,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
   patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
+  patterns.add<ConvertLayoutOpMFMAToDotOpConversion>(typeConverter, targetInfo,
+                                                     benefit);
 }
 } // namespace mlir::triton::AMD