[AMD] NFC: Unified comment style (#5248)
Script:

egrep -nrI --exclude-dir "backend" "^\s*/\*+" third_party/amd
knwng authored Nov 27, 2024
1 parent 9e508a4 commit e7a0561
Showing 13 changed files with 269 additions and 282 deletions.
@@ -11,41 +11,41 @@ using namespace mlir;
using namespace mlir::triton;

// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# WO # W1 # | #
# # # | #
# # # # # | #
# W2 # W3 # .... | #
# # # | SkipElems #
# # # # # | #
# | #
# Slice | #
# . / \ | #
# . / \ | #
# . / \| #
# # # # # # #
# # W0 # W1 # #
# # # # #
# # # # # # tensorStride #
# # W2 # W3 # --------------------------------#
# # # # #
# # # # # # #
# tensorStride # W0 # W1 # #
# ---------------------------------- # # # #
# # # # # # #
# # W2 # W3 # #
# # # # #
# # # # # # ---> lastIdx #
# . #
# . #
# . #
# #
# #
# #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
//===--------------------------------------------------------------------------------===//
// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
// # WO # W1 # | #
// # # # | #
// # # # # # | #
// # W2 # W3 # .... | #
// # # # | SkipElems #
// # # # # # | #
// # | #
// # Slice | #
// # . / \ | #
// # . / \ | #
// # . / \| #
// # # # # # # #
// # # W0 # W1 # #
// # # # # #
// # # # # # # tensorStride #
// # # W2 # W3 # --------------------------------#
// # # # # #
// # # # # # # #
// # tensorStride # W0 # W1 # #
// # ---------------------------------- # # # #
// # # # # # # #
// # # W2 # W3 # #
// # # # # #
// # # # # # # ---> lastIdx #
// # . #
// # . #
// # . #
// # #
// # #
// # #
// # #
// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
//===--------------------------------------------------------------------------------===//
// clang-format on

namespace {
20 changes: 8 additions & 12 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
@@ -99,12 +99,10 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, afterStore);
rewriter.setInsertionPointToStart(trueBlock);
/*
| vialatile | non-tmp | gcn instr gfx94
LLVM::StoreOp | 0 | 0 | (cg) global store
| 0 | 1 | (cs) global store nt
| 1 | 0/1 | (wt) global store sc0 sc1
*/
// | vialatile | non-tmp | gcn instr gfx94
// LLVM::StoreOp | 0 | 0 | (cg) global store
// | 0 | 1 | (cs) global store nt
// | 1 | 0/1 | (wt) global store sc0 sc1
bool vialatileFlag = isPredicatedStoreWT(callOp);
bool nonTmpFlag = isPredicatedStoreCS(callOp);
auto storeOp = rewriter.create<LLVM::StoreOp>(
@@ -136,12 +134,10 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, falseBlock);
rewriter.setInsertionPointToStart(trueBlock);
/*
| vialatile | non-tmp | gcn instr gfx94
LLVM::LoadOp | 0 | 0 | (ca) global load
| 0/1 | 1 | (cg) global load nt
| 1 | 0 | (cv) flat load sc0 sc1
*/
// | vialatile | non-tmp | gcn instr gfx94
// LLVM::LoadOp | 0 | 0 | (ca) global load
// | 0/1 | 1 | (cg) global load nt
// | 1 | 0 | (cv) flat load sc0 sc1
bool vialatileFlag = isPredicatedLoadCV(callOp);
bool nonTmpFlag = isPredicatedLoadCG(callOp);
auto loadOp = rewriter.create<LLVM::LoadOp>(
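For reference, the cache-policy tables in the two comments above encode a simple two-flag selection. The helpers below are an illustrative C++ sketch of that mapping only, not code from this commit; the function names are made up and the strings quote the GCN mnemonics from the comments.

#include <string>

// Picks the gfx94 store policy from the volatile / non-temporal flags,
// mirroring the table in the predicated-store comment above.
std::string selectStorePolicy(bool isVolatile, bool isNonTemporal) {
  if (isVolatile)
    return "global store sc0 sc1";         // (wt)
  return isNonTemporal ? "global store nt" // (cs)
                       : "global store";   // (cg)
}

// Same idea for the predicated-load table.
std::string selectLoadPolicy(bool isVolatile, bool isNonTemporal) {
  if (isNonTemporal)
    return "global load nt";               // (cg)
  return isVolatile ? "flat load sc0 sc1"  // (cv)
                    : "global load";       // (ca)
}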
@@ -82,17 +82,15 @@ bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx) {
return order[0] == kdim;
}

/**
* @brief checks that swizzle pattern fits into one warp block
* and block size is a multiple of swizzle size along non-K dimension
*
* @param sharedLayout
* @param opIdx operand id 0 or 1
* @param reps number of repetitions: [non-k, k] or [batch, non-k, k]
* @param elemsPerInstr one instruction size
* @param warpsPerBlockNonK number of warps along non-k Dim
* @return bool
*/
/// Checks that swizzle pattern fits into one warp block
/// and block size is a multiple of swizzle size along non-K dimension
///
/// \param sharedLayout
/// \param opIdx operand id 0 or 1
/// \param reps number of repetitions: [non-k, k] or [batch, non-k, k]
/// \param elemsPerInstr one instruction size
/// \param warpsPerBlockNonK number of warps along non-k Dim
/// \returns bool
bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout,
int opIdx, const ArrayRef<int64_t> reps,
const ArrayRef<int64_t> elemsPerInstr,
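The doc comment above names two conditions. As an illustrative sketch only (the scalar sizes and names here are assumptions; the real check derives them from sharedLayout, reps, elemsPerInstr and warpsPerBlockNonK), the check amounts to:

// 1) the swizzle pattern fits into one warp block, and
// 2) the warp-block size is a multiple of the swizzle size along non-K.
bool swizzlePatternFitsSketch(int swizzleSizeNonK, int swizzleSizeK,
                              int blockSizeNonK, int blockSizeK) {
  bool fitsInBlock =
      swizzleSizeNonK <= blockSizeNonK && swizzleSizeK <= blockSizeK;
  bool multipleAlongNonK = blockSizeNonK % swizzleSizeNonK == 0;
  return fitsInBlock && multipleAlongNonK;
}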
@@ -13,18 +13,16 @@ Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc,

bool isSwizzled(gpu::SharedEncodingAttr layout);

/**
* @brief swizzling tensor element indexes according pattern encoded in
* SharedEncodingAttr
*
* @param rewriter
* @param loc
* @param row row of target tensor element related to the start of smemObj
* @param col col of target tensor element related to the start of smemObj
* @param smemObj shared memory object, contains info about tensor in LDS
* @param attr layout attribute, contains swizzling info
* @return swizzled row, col indexes in tensor notation
*/
/// Swizzling tensor element indexes according pattern encoded in
/// SharedEncodingAttr
///
/// \param rewriter
/// \param loc
/// \param row row of target tensor element related to the start of smemObj
/// \param col col of target tensor element related to the start of smemObj
/// \param smemObj shared memory object, contains info about tensor in LDS
/// \param attr layout attribute, contains swizzling info
/// \returns swizzled row, col indexes in tensor notation
std::pair<mlir::Value, mlir::Value>
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
Value col, SharedMemoryObject smemObj,
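The swizzle the comment refers to is usually the xor pattern parameterized by SharedEncodingAttr's vec, perPhase and maxPhase. A scalar sketch of that arithmetic, given as an assumption about the encoding rather than a quote of swizzleIndexes (which emits the same computation as MLIR Values):

#include <utility>

std::pair<int, int> swizzleSketch(int row, int col, int vec, int perPhase,
                                  int maxPhase) {
  int phase = (row / perPhase) % maxPhase;     // phase advances with the row
  int swizzledCol =
      ((col / vec) ^ phase) * vec + col % vec; // xor whole vectors, keep order inside one
  return {row, swizzledCol};                   // the row index is not permuted
}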
@@ -33,43 +33,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

namespace SharedToDotOperandMFMA {

/**
* @brief This function maps particular load of mfma dot operand to element
* indexes(row, col)
*
* Whole tensor is broken into "blocks" of warps along "non-K" axis.
* One block could be processed by multiple warps.
* One warp works on a piece of tensor size elemsPerInstr[0] x K.
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
* elemsPerInstr[1].
*
* Total offset of element is a sum of following values:
* 1. Offset of warp-block in tensor
* 2. Offset of warp inside one warp-block
* 3. Offset of tile in one warp
* 4. Offset of one lane data in a tile
* 5. Offset of particular element of tensor processed by one lane
*
* This function computes these offsets for axies independently
* Note that this function returns the offsets of elements in the first
* warp-block. The offsets of elements in later warp-blocks can be computed
* by adding a constant stride to the xor-ed offsets of elements in the
* first warp-block.
*
* @param rewriter
* @param loc
* @param elemsPerInstr operand tile shape consumed by one MFMA instruction
* @param warpId id component of 2d warp grid along non-K axis
* @param laneId lane id in warp [0..63]
* @param numOfElems number of elements accessed by thread per repetition
* @param reps number of instructions repetition to fully cover dot operand
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension size of one MFMA instruction
* @param iKDim K dimension size of one MFMA instruction
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
/// This function maps particular load of mfma dot operand to element
/// indexes(row, col)
///
/// Whole tensor is broken into "blocks" of warps along "non-K" axis.
/// One block could be processed by multiple warps.
/// One warp works on a piece of tensor size elemsPerInstr[0] x K.
/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
/// elemsPerInstr[1].
///
/// Total offset of element is a sum of following values:
/// 1. Offset of warp-block in tensor
/// 2. Offset of warp inside one warp-block
/// 3. Offset of tile in one warp
/// 4. Offset of one lane data in a tile
/// 5. Offset of particular element of tensor processed by one lane
///
/// This function computes these offsets for axies independently
/// Note that this function returns the offsets of elements in the first
/// warp-block. The offsets of elements in later warp-blocks can be computed
/// by adding a constant stride to the xor-ed offsets of elements in the
/// first warp-block.
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape consumed by one MFMA instruction
/// \param warpId id component of 2d warp grid along non-K axis
/// \param laneId lane id in warp [0..63]
/// \param numOfElems number of elements accessed by thread per repetition
/// \param reps number of instructions repetition to fully cover dot operand
/// \param smemStrides strides in LDS tensor
/// \param loadVecSize number of elements loaded by one operation
/// \param iNonKDim non-K dimension size of one MFMA instruction
/// \param iKDim K dimension size of one MFMA instruction
/// \returns vector (i-th element corresponds to i-th load instruction) of
/// 2-element vectors(tensor row and col).
llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId, Value laneId,
@@ -127,17 +125,18 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
return srcEncoding.getMaxPhase() > 1;
}

// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
// instruction
// @param warpId warp id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerBlock number of warps per horizontal axis
// @param numOfElems number of elements accessed by threads per repetition
// @param reps number of instructions repretition to fully cover dot operand
// @param cSwizzleOffset
/// Computes offsets for operand B or transposed operand A
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA
/// instruction
/// \param warpId warp id for the "non K" axis
/// \param laneId lane id in warp [0..63]
/// \param warpsPerBlock number of warps per horizontal axis
/// \param numOfElems number of elements accessed by threads per repetition
/// \param reps number of instructions repretition to fully cover dot operand
/// \param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value warpId,
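The note in the comment above (later warp-blocks reuse the first block's xor-ed offsets shifted by a constant stride) can be pictured with a small sketch. This is illustrative only: the names and the plain int64_t representation are assumptions, since the pass builds these offsets as MLIR Values.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Offsets for warp-block `blockId`: the first block's (already xor-ed)
// offsets plus a constant per-block stride.
llvm::SmallVector<int64_t>
offsetsForWarpBlock(llvm::ArrayRef<int64_t> firstBlockOffsets,
                    int64_t blockStride, int64_t blockId) {
  llvm::SmallVector<int64_t> offsets;
  for (int64_t off : firstBlockOffsets)
    offsets.push_back(off + blockStride * blockId);
  return offsets;
}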
@@ -33,39 +33,37 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

namespace SharedToDotOperandWMMA {

/**
* @brief Following functions maps particular load of wmma dot operand to
* element indexes(row, col). For each WMMA generation separate function is
* used.
*
* Whole tensor is broken into "blocks" of warps along "non-K" axis.
* One block could be processed by multiple warps.
* One warp works on a piece of tensor size elemsPerInstr[0] x K.
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
* elemsPerInstr[1].
*
* Total offset of element is a sum of following values:
* 1. Offset of warp block in tensor
* 2. Offset of warp inside one warp block
* 3. Offset of tile in one warp
* 4. Offset of one lane data in a tile
* 5. Offset of particular element of tensor processed by one lane
*
* This function computes these offsets for axes independently
*
* @param rewriter
* @param loc
* @param elemsPerInstr operand tile shape consumed by one WMMA instruction
* @param warpId id component of 2d warp grid along non-K axis
* @param laneId lane id in warp [0..63]
* @param numOfElems number of elements accessed by thread per repetition
* @param reps number of instructions repetition to fully cover dot operand
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of dot operand
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
/// Following functions maps particular load of wmma dot operand to
/// element indexes(row, col). For each WMMA generation separate function is
/// used.
///
/// Whole tensor is broken into "blocks" of warps along "non-K" axis.
/// One block could be processed by multiple warps.
/// One warp works on a piece of tensor size elemsPerInstr[0] x K.
/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
/// elemsPerInstr[1].
///
/// Total offset of element is a sum of following values:
/// 1. Offset of warp block in tensor
/// 2. Offset of warp inside one warp block
/// 3. Offset of tile in one warp
/// 4. Offset of one lane data in a tile
/// 5. Offset of particular element of tensor processed by one lane
///
/// This function computes these offsets for axes independently
///
/// \param rewriter
/// \param loc
/// \param elemsPerInstr operand tile shape consumed by one WMMA instruction
/// \param warpId id component of 2d warp grid along non-K axis
/// \param laneId lane id in warp [0..63]
/// \param numOfElems number of elements accessed by thread per repetition
/// \param reps number of instructions repetition to fully cover dot operand
/// \param smemStrides strides in LDS tensor
/// \param loadVecSize number of elements loaded by one operation
/// \param iNonKDim non-K dimension of dot operand
/// \returns vector (i-th element corresponds to i-th load instruction) of
/// 2-element vectors(tensor row and col).
llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMappingInBlockWmma1(
ConversionPatternRewriter &rewriter, Location loc,
@@ -41,9 +41,7 @@ struct DecomposeUnsupportedAMDConversions

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

/* -------------------------------- */
// Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
/* -------------------------------- */
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getSrc().getType();
