diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index ae517912fbb4..df6029db0de2 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -218,11 +218,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); -// Check if MFMA layout can be converted to the dot operand -// layout using warp shuffle. -bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, - RankedTensorType dstTy); - // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 5fd87e4c0169..6166e1019901 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -10,7 +10,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" -#include "triton/Conversion/MLIRTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -633,25 +632,6 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } -bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, - RankedTensorType dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - if (!mfmaLayout || !dotOperandLayout) - return false; - - // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case - return dotOperandLayout.getParent() == mfmaLayout && - dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && - dotOperandLayout.getKWidth() == 8 && - getContigPerThread(mfmaLayout)[1] == 4 && - ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) || - (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) && - triton::type::isFloat8(srcTy.getElementType()) && - triton::type::isFloat8(dstTy.getElementType()) && - mfmaLayout.getWarpsPerCTA()[1] == 1; -} - // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -750,10 +730,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { // supported yet in Triton's backend. return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && - // to be removed when generalized warp shuffle conversions - // are ready: - !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index f0026c199324..aab97c7dd2b3 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -409,12 +409,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return failure(); } - // The following check can be removed when generalized warp shuffle - // conversions are ready: - if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { - return failure(); - } - assert(cvtNeedsSharedMemory(srcTy, dstTy)); SmallVector inVals = diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index bcbc7eff590e..a2c8f48718d9 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s #mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> @@ -27,191 +27,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } - -// ----- - -#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 - tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - - // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - - // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] - - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] - - // Input (8 values): (vec0, vec1) - // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): - // resVec0 resVec1 - // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) - // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) - - // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] - // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] - - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] - // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] - - // CHECK: llvm.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> - tt.return - } -} - -// ----- - -#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 - tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute - // CHECK: llvm.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> - tt.return - } -} - -// ----- - -#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 - tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - - // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - - // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) - // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) - // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - - // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] - // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] - - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] - // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] - // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] - // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] - - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] - - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] - - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] - // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] - - // Input (8 values): (vec0, vec1) - // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): - // resVec0 resVec1 - // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1) - // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0) - // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) - // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) - - // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> - // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> - // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - - // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> - // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> - // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] - // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] - - // CHECK: llvm.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> - tt.return - } -} - -// ----- - -#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 - tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute - // CHECK: llvm.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> - tt.return - } -} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3b61fb8cc467..208483beb8fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -116,158 +116,6 @@ struct LocalLoadOpConversion } }; -struct ConvertLayoutOpMFMAToDotOpConversion - : public ConvertOpToLLVMPattern { -public: - explicit ConvertLayoutOpMFMAToDotOpConversion( - LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, - benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = cast(op.getSrc().getType()); - auto dstType = cast(op.getType()); - - if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) - return failure(); - - auto loc = op.getLoc(); - - SmallVector inVals = - unpackLLElements(loc, adaptor.getSrc(), rewriter); - if (inVals.empty() || inVals.size() % 8 != 0) - return failure(); - - auto mfmaLayout = dyn_cast(srcType.getEncoding()); - assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && - "Expected MFMA size 16 or 32"); - assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && - "Expected warp size 64 for MFMA"); - - auto elemTy = int_ty(8); - auto vecTy = vec_ty(elemTy, 4); - - Value c16 = i32_val(16); - Value c32 = i32_val(32); - Value c48 = i32_val(48); - Value c64 = i32_val(64); - - Value threadId = tid_val(); - Value laneId = urem(threadId, c64); - - Value mask0 = icmp_slt(laneId, c32); - Value mask1 = icmp_slt(urem(laneId, c32), c16); - - Value addrShift16 = urem(add(laneId, c16), c64); - Value addrShift32 = urem(add(laneId, c32), c64); - Value addrShift48 = urem(add(laneId, c48), c64); - - SmallVector outVals; - for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { - Value vec0 = undef(vecTy); - for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec0 = - insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); - } - Value vec1 = undef(vecTy); - for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], - i32_val(vIdx)); - } - - Value resVec0, resVec1; - if (mfmaLayout.getMDim() == 32) { - /* - Using wave shuffle to convert layouts (32x32x16 case): - 1) Input MMA layout (32x32, fp8, 16 values): - _____________________________________________________________ - |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| - | ... ... | - |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| - |_____________________________________________________________| - - 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): - ____________________________________________________________ ___ - |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) || - | ... ... ||... - |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) || - |____________________________________________________________||___ - */ - - Value shflVec0 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), - vecTy); - Value shflVec1 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), - vecTy); - - resVec0 = select(mask0, vec0, shflVec1); - resVec1 = select(mask0, shflVec0, vec1); - } else if (mfmaLayout.getMDim() == 16) { - /* - 16x16x32 case: - 1) Input MMA layout (two 16x16, fp8, 4 values each): - _________________________________________________________ ___________ - |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ... - | ... ... || ... - |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ... - |_________________________________________________________||___________ - - 2) Output Dot operand layout (16x32 tile, fp8, 8 values): - ________________________________________________________________ - |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) | - | ... ... | - |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) | - |________________________________________________________________| - */ - - Value shflVec0_16 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16), - vecTy); - Value shflVec0_32 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), - vecTy); - Value shflVec1_32 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), - vecTy); - Value shflVec1_48 = - bitcast(targetInfo.shuffleIdx( - rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48), - vecTy); - - resVec0 = select(mask0, select(mask1, vec0, shflVec0_16), - select(mask1, shflVec1_32, shflVec1_48)); - resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32), - select(mask1, shflVec1_48, vec1)); - } - - for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx))); - } - for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx))); - } - } - - Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, - op.getType()); - rewriter.replaceOp(op, result); - return success(); - } - -protected: - const TargetInfoBase &targetInfo; -}; - } // namespace namespace mlir::triton::AMD { @@ -276,7 +124,5 @@ void populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, targetInfo, - benefit); } } // namespace mlir::triton::AMD