diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index ad56bd2d414e..eb2cde1a93f2 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -11,41 +11,41 @@ using namespace mlir; using namespace mlir::triton; // clang-format off -/*** - # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - # WO # W1 # | # - # # # | # - # # # # # | # - # W2 # W3 # .... | # - # # # | SkipElems # - # # # # # | # - # | # - # Slice | # - # . / \ | # - # . / \ | # - # . / \| # - # # # # # # # - # # W0 # W1 # # - # # # # # - # # # # # # tensorStride # - # # W2 # W3 # --------------------------------# - # # # # # - # # # # # # # - # tensorStride # W0 # W1 # # - # ---------------------------------- # # # # - # # # # # # # - # # W2 # W3 # # - # # # # # - # # # # # # ---> lastIdx # - # . # - # . # - # . # - # # - # # - # # - # # - # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -***/ +//===--------------------------------------------------------------------------------===// +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +// # WO # W1 # | # +// # # # | # +// # # # # # | # +// # W2 # W3 # .... | # +// # # # | SkipElems # +// # # # # # | # +// # | # +// # Slice | # +// # . / \ | # +// # . / \ | # +// # . / \| # +// # # # # # # # +// # # W0 # W1 # # +// # # # # # +// # # # # # # tensorStride # +// # # W2 # W3 # --------------------------------# +// # # # # # +// # # # # # # # +// # tensorStride # W0 # W1 # # +// # ---------------------------------- # # # # +// # # # # # # # +// # # W2 # W3 # # +// # # # # # +// # # # # # # ---> lastIdx # +// # . # +// # . # +// # . # +// # # +// # # +// # # +// # # +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +//===--------------------------------------------------------------------------------===// // clang-format on namespace { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index a8ec9276c45b..de92fa01441a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -99,12 +99,10 @@ class CallOpConversion : public OpRewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, afterStore); rewriter.setInsertionPointToStart(trueBlock); - /* - | vialatile | non-tmp | gcn instr gfx94 - LLVM::StoreOp | 0 | 0 | (cg) global store - | 0 | 1 | (cs) global store nt - | 1 | 0/1 | (wt) global store sc0 sc1 - */ + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::StoreOp | 0 | 0 | (cg) global store + // | 0 | 1 | (cs) global store nt + // | 1 | 0/1 | (wt) global store sc0 sc1 bool vialatileFlag = isPredicatedStoreWT(callOp); bool nonTmpFlag = isPredicatedStoreCS(callOp); auto storeOp = rewriter.create( @@ -136,12 +134,10 @@ class CallOpConversion : public OpRewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, falseBlock); rewriter.setInsertionPointToStart(trueBlock); - /* - | vialatile | non-tmp | gcn instr gfx94 - LLVM::LoadOp | 0 | 0 | (ca) global load - | 0/1 | 1 | (cg) global load nt - | 1 | 0 | (cv) flat load sc0 sc1 - */ + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::LoadOp | 0 | 0 | (ca) global load + // | 0/1 | 1 | (cg) global load nt + // | 1 | 0 | (cv) flat load sc0 sc1 bool vialatileFlag = isPredicatedLoadCV(callOp); bool nonTmpFlag = isPredicatedLoadCG(callOp); auto loadOp = rewriter.create( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 03b7c56b7e6b..46d60e2c5da3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -82,17 +82,15 @@ bool isKMajor(llvm::ArrayRef order, int opIdx) { return order[0] == kdim; } -/** - * @brief checks that swizzle pattern fits into one warp block - * and block size is a multiple of swizzle size along non-K dimension - * - * @param sharedLayout - * @param opIdx operand id 0 or 1 - * @param reps number of repetitions: [non-k, k] or [batch, non-k, k] - * @param elemsPerInstr one instruction size - * @param warpsPerBlockNonK number of warps along non-k Dim - * @return bool - */ +/// Checks that swizzle pattern fits into one warp block +/// and block size is a multiple of swizzle size along non-K dimension +/// +/// \param sharedLayout +/// \param opIdx operand id 0 or 1 +/// \param reps number of repetitions: [non-k, k] or [batch, non-k, k] +/// \param elemsPerInstr one instruction size +/// \param warpsPerBlockNonK number of warps along non-k Dim +/// \returns bool bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout, int opIdx, const ArrayRef reps, const ArrayRef elemsPerInstr, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h index b2c6759fcb59..1b0e3b2df003 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h @@ -13,18 +13,16 @@ Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc, bool isSwizzled(gpu::SharedEncodingAttr layout); -/** - * @brief swizzling tensor element indexes according pattern encoded in - * SharedEncodingAttr - * - * @param rewriter - * @param loc - * @param row row of target tensor element related to the start of smemObj - * @param col col of target tensor element related to the start of smemObj - * @param smemObj shared memory object, contains info about tensor in LDS - * @param attr layout attribute, contains swizzling info - * @return swizzled row, col indexes in tensor notation - */ +/// Swizzling tensor element indexes according pattern encoded in +/// SharedEncodingAttr +/// +/// \param rewriter +/// \param loc +/// \param row row of target tensor element related to the start of smemObj +/// \param col col of target tensor element related to the start of smemObj +/// \param smemObj shared memory object, contains info about tensor in LDS +/// \param attr layout attribute, contains swizzling info +/// \returns swizzled row, col indexes in tensor notation std::pair swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, Value col, SharedMemoryObject smemObj, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 460bb37ab583..e55d87cb9434 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -33,43 +33,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandMFMA { -/** - * @brief This function maps particular load of mfma dot operand to element - * indexes(row, col) - * - * Whole tensor is broken into "blocks" of warps along "non-K" axis. - * One block could be processed by multiple warps. - * One warp works on a piece of tensor size elemsPerInstr[0] x K. - * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x - * elemsPerInstr[1]. - * - * Total offset of element is a sum of following values: - * 1. Offset of warp-block in tensor - * 2. Offset of warp inside one warp-block - * 3. Offset of tile in one warp - * 4. Offset of one lane data in a tile - * 5. Offset of particular element of tensor processed by one lane - * - * This function computes these offsets for axies independently - * Note that this function returns the offsets of elements in the first - * warp-block. The offsets of elements in later warp-blocks can be computed - * by adding a constant stride to the xor-ed offsets of elements in the - * first warp-block. - * - * @param rewriter - * @param loc - * @param elemsPerInstr operand tile shape consumed by one MFMA instruction - * @param warpId id component of 2d warp grid along non-K axis - * @param laneId lane id in warp [0..63] - * @param numOfElems number of elements accessed by thread per repetition - * @param reps number of instructions repetition to fully cover dot operand - * @param smemStrides strides in LDS tensor - * @param loadVecSize number of elements loaded by one operation - * @param iNonKDim non-K dimension size of one MFMA instruction - * @param iKDim K dimension size of one MFMA instruction - * @return vector (i-th element corresponds to i-th load instruction) of - * 2-element vectors(tensor row and col). - */ +/// This function maps particular load of mfma dot operand to element +/// indexes(row, col) +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp-block in tensor +/// 2. Offset of warp inside one warp-block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axies independently +/// Note that this function returns the offsets of elements in the first +/// warp-block. The offsets of elements in later warp-blocks can be computed +/// by adding a constant stride to the xor-ed offsets of elements in the +/// first warp-block. +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one MFMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension size of one MFMA instruction +/// \param iKDim K dimension size of one MFMA instruction +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). llvm::SmallVector> computeTensorElemMappingInBlock( ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, @@ -127,17 +125,18 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { return srcEncoding.getMaxPhase() > 1; } -// Computes offsets for operand B or transposed operand A -// @param rewriter -// @param loc -// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA -// instruction -// @param warpId warp id for the "non K" axis -// @param laneId lane id in warp [0..63] -// @param warpsPerBlock number of warps per horizontal axis -// @param numOfElems number of elements accessed by threads per repetition -// @param reps number of instructions repretition to fully cover dot operand -// @param cSwizzleOffset +/// Computes offsets for operand B or transposed operand A +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA +/// instruction +/// \param warpId warp id for the "non K" axis +/// \param laneId lane id in warp [0..63] +/// \param warpsPerBlock number of warps per horizontal axis +/// \param numOfElems number of elements accessed by threads per repetition +/// \param reps number of instructions repretition to fully cover dot operand +/// \param cSwizzleOffset llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 7f037b89b854..8d5bc669e1eb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -33,39 +33,37 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandWMMA { -/** - * @brief Following functions maps particular load of wmma dot operand to - * element indexes(row, col). For each WMMA generation separate function is - * used. - * - * Whole tensor is broken into "blocks" of warps along "non-K" axis. - * One block could be processed by multiple warps. - * One warp works on a piece of tensor size elemsPerInstr[0] x K. - * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x - * elemsPerInstr[1]. - * - * Total offset of element is a sum of following values: - * 1. Offset of warp block in tensor - * 2. Offset of warp inside one warp block - * 3. Offset of tile in one warp - * 4. Offset of one lane data in a tile - * 5. Offset of particular element of tensor processed by one lane - * - * This function computes these offsets for axes independently - * - * @param rewriter - * @param loc - * @param elemsPerInstr operand tile shape consumed by one WMMA instruction - * @param warpId id component of 2d warp grid along non-K axis - * @param laneId lane id in warp [0..63] - * @param numOfElems number of elements accessed by thread per repetition - * @param reps number of instructions repetition to fully cover dot operand - * @param smemStrides strides in LDS tensor - * @param loadVecSize number of elements loaded by one operation - * @param iNonKDim non-K dimension of dot operand - * @return vector (i-th element corresponds to i-th load instruction) of - * 2-element vectors(tensor row and col). - */ +/// Following functions maps particular load of wmma dot operand to +/// element indexes(row, col). For each WMMA generation separate function is +/// used. +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp block in tensor +/// 2. Offset of warp inside one warp block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axes independently +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one WMMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension of dot operand +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). llvm::SmallVector> computeTensorElemMappingInBlockWmma1( ConversionPatternRewriter &rewriter, Location loc, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 7baa9485b306..bbacde54b041 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -41,9 +41,7 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); - /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` - /* -------------------------------- */ mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { OpBuilder builder(cvtOp); auto srcType = cvtOp.getSrc().getType(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 12e6100380c9..54e3c6ac8527 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -274,11 +274,9 @@ struct DotOpMFMAConversionHelper { return success(); } - /** - * @brief extract vector from rawElems based on kWidth and kBase - * rawElems is a vector of kWidth elements. We need to prepare vector(s) of - * kBase elements for each mfma instruction - */ + /// Extract vector from rawElems based on kWidth and kBase + /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of + /// kBase elements for each mfma instruction SmallVector extractOperands(Value rawElems, int kWidth, int kBase, Type type) const { int kpack = kWidth / kBase; @@ -311,10 +309,8 @@ struct DotOpMFMAConversionHelper { return results; } - /** - * @brief Converts dot operand structure to value table and converts types - * appropriate for mfma instructions - */ + /// Converts dot operand structure to value table and converts types + /// appropriate for mfma instructions SmallVector getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type) const { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index b85497dcb4bf..716a93865ddd 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -21,8 +21,14 @@ typedef std::function(Location, ConversionPatternRewriter &, ConverterT; namespace { -// ROCM utility functions for data type conversion -/* ----- FP8E5M2 ------ */ +//===-------------------------------------------===// +/// ROCM utility functions for data type conversion +//===-------------------------------------------===// + +//===----------------===// +/// FP8E5M2 +//===----------------===// + // This data-type is the standard FP8E5M2 format // NVIDIA GPU supports it natively but we don't have hardware native // support on MI300. @@ -221,6 +227,7 @@ Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, assert(v.size() == 2); return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8"); } + // Depend on whether we focus more on performance, we may skip // the processing of submornal values static Value Fp16_to_Fp8E5M2FNUZ_oneValue(Location loc, @@ -537,7 +544,9 @@ static SmallVector Bf16_to_Fp8E5M2(Location loc, extract_element(i8_ty, fp8x4Vec, i32_val(3))}; } -// ROCM type conversion between fp8 and bf16 +//===-----------------------------------------===// +/// ROCM type conversion between fp8 and bf16 +//===-----------------------------------------===// // fp8e4m3fn to bf16 static SmallVector Fp8E4M3FN_to_Bf16(Location loc, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h index 701f03b129f3..6b902b303c81 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h @@ -11,27 +11,24 @@ int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op); std::vector> factorizePowerOf2(int n, int rank); -/** - * @brief Copy given layout with different warpsPerCTA parameter - * @param layout original layout - * @param warpsPerCTA new warpsPerCTA - * @return create layout - */ +/// Copy given layout with different warpsPerCTA parameter +/// +/// \param layout original layout +/// \param warpsPerCTA new warpsPerCTA +/// \returns create layout Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA); -/** - * Creates two chained convert layout operations - * - * %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation - * -> - * %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first - * %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second - * - * @param builder - * @param cvtOp original operation - * @param tmpLayout - * @return pair of created operations - */ +/// Creates two chained convert layout operations +/// +/// %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation +/// -> +/// %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first +/// %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second +/// +/// \param builder +/// \param cvtOp original operation +/// \param tmpLayout +/// \returns pair of created operations std::pair createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, Attribute tmpLayout); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index cca1714f6581..9a0098790057 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -228,45 +228,43 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, Value buf; auto valType = acc[i].getType(); - /* - Here's the implementation of full-wavefront reduction using dpp. - https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ - - Each step has a v_mov_dpp instruction following the redux op. In - some cases, the lower-level compiler could merge them into single - instruction. For example, v_mov_dpp + max => v_max_dpp. - - For gfx9, we have 64 threads per warp. These 64 threads are arranged - into 4 rows, with each row being 16 threads. Each 16 threads are arranged - further into 4 banks, with each bank being 4 threads. Overall it's in a - (row, bank, thread) structure. When shuffling, we use row/bank mask to - indicate which row/bank to participate. Then modifier like row_shr and - row_bcast means exact data movement schemes. In the following - instructions, taking row 0 as an example: - - Step 1: Right shift for 8 lanes. - lane 8-15 = redux(lane 0-7, lane 8-15) - - Step 2: Right shift for 4 lanes. - lane 12-15 = redux(lane 8-11, lane 12-15) - - Step 3: Right shift for 2 lanes. - lane 14-15 = redux(lane 12-13, lane 14-15) - - Step 4: Right shift for 1 lane. - lane 15 = redux(lane 14, lane 15) - - Step 5: Broadcast lane 15 of each row to all the lanes of its next row. - lane 16-31 = redux(lane 15, lane 16-31) - - Step 6: Broadcast lane 31 to lane 32-63. - lane 32-63 = redux(lane 31, lane 32-63) - - Now the reduction result is stored in lane 63. - - Step 7: Read the reduction result from lane 63 and broadcast with - readlane. - */ + // Here's the implementation of full-wavefront reduction using dpp. + // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + // + // Each step has a v_mov_dpp instruction following the redux op. In + // some cases, the lower-level compiler could merge them into single + // instruction. For example, v_mov_dpp + max => v_max_dpp. + // + // For gfx9, we have 64 threads per warp. These 64 threads are arranged + // into 4 rows, with each row being 16 threads. Each 16 threads are arranged + // further into 4 banks, with each bank being 4 threads. Overall it's in a + // (row, bank, thread) structure. When shuffling, we use row/bank mask to + // indicate which row/bank to participate. Then modifier like row_shr and + // row_bcast means exact data movement schemes. In the following + // instructions, taking row 0 as an example: + // + // Step 1: Right shift for 8 lanes. + // lane 8-15 = redux(lane 0-7, lane 8-15) + // + // Step 2: Right shift for 4 lanes. + // lane 12-15 = redux(lane 8-11, lane 12-15) + // + // Step 3: Right shift for 2 lanes. + // lane 14-15 = redux(lane 12-13, lane 14-15) + // + // Step 4: Right shift for 1 lane. + // lane 15 = redux(lane 14, lane 15) + // + // Step 5: Broadcast lane 15 of each row to all the lanes of its next row. + // lane 16-31 = redux(lane 15, lane 16-31) + // + // Step 6: Broadcast lane 31 to lane 32-63. + // lane 32-63 = redux(lane 31, lane 32-63) + // + // Now the reduction result is stored in lane 63. + // + // Step 7: Read the reduction result from lane 63 and broadcast with + // readlane. const int allRows = 0xf; const int allBanks = 0xf; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index e8d3607c46ab..f9cd0f14382e 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -265,23 +265,23 @@ OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter, return selectMatrixCoreOperandTypes(dot, applicableTypes); } -/** - * @brief Convert layout and cast element type of a given tensor - * - * If old element type is different from new element type, this function - * creates two new operations: - * 1. %converted_value = layout_convert %value, newEncoding - * 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType - * - * If old element type is same as new element type, this function creates only - * one operation: %converted_value = layout_convert %value, newEncoding - * - * @param rewriter - * @param value original tensor value, which we need to convert and cast - * @param newEncoding new encoding for the tenosr - * @param newElemType new element type for the tensor - * @return converted and optionaly casted tensor value - */ +//===---------------------------------------------------------------------===// +// @brief Convert layout and cast element type of a given tensor +// +// If old element type is different from new element type, this function +// creates two new operations: +// 1. %converted_value = layout_convert %value, newEncoding +// 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType +// +// If old element type is same as new element type, this function creates only +// one operation: %converted_value = layout_convert %value, newEncoding +// +// @param rewriter +// @param value original tensor value, which we need to convert and cast +// @param newEncoding new encoding for the tenosr +// @param newElemType new element type for the tensor +// @return converted and optionaly casted tensor value +//===---------------------------------------------------------------------===// Value convertAndCastTensor(PatternRewriter &rewriter, Value value, Attribute newEncoding, Type newElemType) { assert(newElemType.isIntOrFloat()); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 0837f16dcf7c..eabb2ad66bfd 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -280,48 +280,48 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) { } } -/** - * Sched-load optimization for matmul kernels with large tile sizes - * The basic idea of sched-load optimization is to sink the 2nd tt.load - * after local_load so that global_load instructions can be interleaved with - * mfma's. This can help hide the issue latency of global_load instructions - * and improve performance on MI300X. - * - * It's assumed that the IR before this optimization has the following - * structure: - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * tileB = tt.load b_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * After this optimization, the IR is transformed to - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * tileB = tt.load b_ptr <-- 2nd tt.load is sinked here - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * For now, we don't have a perfect hueristic about when should this - * optimization be applied. Therefore, we implement a simple hueristic that - * this is applied when the tile size of A and B are large enough, i.e. - * nonKDim >= 128 and kDim >= 64. And also this is only applied for typical - * matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We - * are experimenting how to better control instruction scheduling and enable - * such optimizations. - */ +//===-------------------------------------------------------------------===// +// Sched-load optimization for matmul kernels with large tile sizes +// The basic idea of sched-load optimization is to sink the 2nd tt.load +// after local_load so that global_load instructions can be interleaved with +// mfma's. This can help hide the issue latency of global_load instructions +// and improve performance on MI300X. +// +// It's assumed that the IR before this optimization has the following +// structure: +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// tileB = tt.load b_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// After this optimization, the IR is transformed to +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// tileB = tt.load b_ptr <-- 2nd tt.load is sinked here +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// For now, we don't have a perfect hueristic about when should this +// optimization be applied. Therefore, we implement a simple hueristic that +// this is applied when the tile size of A and B are large enough, i.e. +// nonKDim >= 128 and kDim >= 64. And also this is only applied for typical +// matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We +// are experimenting how to better control instruction scheduling and enable +// such optimizations. +//===-------------------------------------------------------------------===// static void sinkSecondLoad(triton::FuncOp funcOp) { funcOp.walk([&](scf::ForOp forOp) -> void { SetVector loadOps;