Skip to content

Commit

Permalink
Resolved comments
Browse files Browse the repository at this point in the history
  • Loading branch information
knwng committed Nov 11, 2024
1 parent 8bb6050 commit 7ff9683
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 78 deletions.
16 changes: 0 additions & 16 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1417,22 +1417,6 @@ inline Value packLLVector(Location loc, ValueRange vals,
return vec;
}

inline Operation *getSingleCombinerFromReduceOp(triton::ReduceOp op) {
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
return nullptr;
Block *block = &(*op.getCombineOp().begin());
Operation *yield = block->getTerminator();
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
reduceOp->getNumResults() != 1)
return nullptr;
if (reduceOp->getOperand(0) != block->getArgument(0) ||
reduceOp->getOperand(1) != block->getArgument(1))
return nullptr;

return reduceOp;
}

} // namespace mlir

#endif
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ def TT_ReduceOp: TT_Op<"reduce",
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();
::mlir::Operation *getSingleCombiner();
}];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
return getElementTypesImpl(this->getOperands());
}

::mlir::Operation *ReduceOp::getSingleCombiner() {
if (getNumOperands() != 1 || getNumResults() != 1)
return nullptr;
Block *block = &(*getCombineOp().begin());
Operation *yield = block->getTerminator();
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
reduceOp->getNumResults() != 1)
return nullptr;
if (reduceOp->getOperand(0) != block->getArgument(0) ||
reduceOp->getOperand(1) != block->getArgument(1))
return nullptr;

return reduceOp;
}

unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }

//-- ScanOp --
Expand Down
2 changes: 2 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_dpp_max
tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 15, true : f32
Expand Down Expand Up @@ -138,6 +139,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_xor_max
tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
// CHECK: rocdl.ds_swizzle
// CHECK: llvm.intr.maxnum
Expand Down
5 changes: 3 additions & 2 deletions third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ ISAFamily deduceISAFamily(llvm::StringRef arch);

// Here is a partial definition of DppCtrl enums. For the complete definition,
// please check:
// https://github.com/llvm/llvm-project/blob/llvmorg-19.1.3/llvm/lib/Target/AMDGPU/SIDefines.h
enum DppCtrl : uint32_t {
// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
enum class DppCtrl : uint32_t {
QUAD_PERM_FIRST = 0,
ROW_SHL0 = 0x100,
ROW_SHR0 = 0x110,
BCAST15 = 0x142,
Expand Down
99 changes: 61 additions & 38 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using mlir::getSingleCombinerFromReduceOp;

using mlir::triton::AMD::DppCtrl;
namespace mlir::triton::AMD {

namespace {
Expand Down Expand Up @@ -128,35 +127,45 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);
}

static inline Type castToInt(RewriterBase &rewriter, Location loc, Value &val,
Type valType, unsigned bits) {
unsigned originalBits = valType.getIntOrFloatBitWidth();
Type actualType = valType;
// Cast and sext values into specific-length int to meet the requirements of
// instructions like UpdateDpp or readlane if necessary.
static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc,
Value &val, Type fromType,
unsigned toBits) {
unsigned originalBits = fromType.getIntOrFloatBitWidth();
Type toType = fromType;

if (!valType.isIntOrIndex()) {
if (!fromType.isIntOrIndex()) {
val = bitcast(val, int_ty(originalBits));
actualType = int_ty(originalBits);
toType = int_ty(originalBits);
}

if (originalBits < bits) {
val = sext(int_ty(bits), val);
actualType = int_ty(bits);
if (originalBits < toBits) {
val = sext(int_ty(toBits), val);
toType = int_ty(toBits);
}

return actualType;
return toType;
}

static inline void castFromInt(RewriterBase &rewriter, Location loc, Value &val,
Type valType, unsigned bits) {
// Trunc the value to specific length and then cast it to given type if
// necessary. This function is typically used in conjunction with
// castToAndSExtInt.
static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
Value val, Type valType,
unsigned fromBits) {
unsigned originalBits = valType.getIntOrFloatBitWidth();
Value toVal = val;

if (originalBits < bits) {
val = trunc(int_ty(originalBits), val);
if (originalBits < fromBits) {
toVal = trunc(int_ty(originalBits), toVal);
}

if (!valType.isIntOrIndex()) {
val = bitcast(val, valType);
toVal = bitcast(toVal, valType);
}

return toVal;
}

bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
Expand All @@ -171,19 +180,19 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
return false;
}

Operation *reduxOp = getSingleCombinerFromReduceOp(op);
Operation *reduxOp = op.getSingleCombiner();
if (!reduxOp)
return false;

auto createDppReduxOp = [&](Type valType, Value &src, int dppCtrl,
int rowMask, int bankMask,
bool boundCtrl) -> Value {
auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src,
uint32_t dppCtrl, int rowMask,
int bankMask) -> Value {
// DPP has limited support for data types, so here we need to
// cast non-integer types or integer types shorter than 32 bits
// to int32, except for fp32.
Type actualType = valType;
if (!valType.isF32()) {
actualType = castToInt(rewriter, loc, src, valType, 32);
actualType = castToAndSExtInt(rewriter, loc, src, valType, 32);
}

Value dppResult =
Expand All @@ -192,12 +201,12 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
rewriter.getI32IntegerAttr(dppCtrl),
rewriter.getI32IntegerAttr(rowMask),
rewriter.getI32IntegerAttr(bankMask),
rewriter.getBoolAttr(boundCtrl))
rewriter.getBoolAttr(true))
.getRes();

if (!valType.isF32()) {
castFromInt(rewriter, loc, src, valType, 32);
castFromInt(rewriter, loc, dppResult, valType, 32);
src = truncAndCastFromInt(rewriter, loc, src, valType, 32);
dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32);
}

IRMapping mapping;
Expand All @@ -218,10 +227,13 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
some cases, the lower-level compiler could merge them into single
instruction. For example, v_mov_dpp + max => v_max_dpp.
In DPP, each row consists of 16 consecutive lanes.
So the modifier row_shr and row_bcast mean they have the same operations
in each row, so in the following instructions, we only take row 0
as an example:
For gfx9, we have 64 threads per warp. These 64 threads are arranged
into 4 rows, with each row being 16 threads. Each 16 threads are arranged
further into 4 banks, with each bank being 4 threads. Overall it's in a
(row, bank, thread) structure. When shuffling, we use row/bank mask to
indicate which row/bank to participate. Then modifier like row_shr and
row_bcast means exact data movement schemes. In the following
instructions, taking row 0 as an example:
Step 1: Right shift for 8 lanes.
lane 8-15 = redux(lane 0-7, lane 8-15)
Expand All @@ -247,27 +259,38 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
readlane.
*/

const int allRows = 0xf;
const int allBanks = 0xf;

const uint32_t dppCtrlRowShr = static_cast<uint32_t>(DppCtrl::ROW_SHR0);

// row_shr:8
buf = createDppReduxOp(valType, acc[i], 8 + DppCtrl::ROW_SHR0, 0xf, 0xf,
true);
buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr,
allRows, allBanks);

// row_shr:4
buf = createDppReduxOp(valType, buf, 4 + DppCtrl::ROW_SHR0, 0xf, 0xf, true);
buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr,
allRows, allBanks);

// row_shr:2
buf = createDppReduxOp(valType, buf, 2 + DppCtrl::ROW_SHR0, 0xf, 0xf, true);
buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr,
allRows, allBanks);

// row_shr:1
buf = createDppReduxOp(valType, buf, 1 + DppCtrl::ROW_SHR0, 0xf, 0xf, true);
buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
allRows, allBanks);

// row_bcast:15 row_mask:0xa
buf = createDppReduxOp(valType, buf, DppCtrl::BCAST15, 0xa, 0xf, true);
buf = createDppReduxOpWithBoundCtrl(
valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);

// row_bcast:31
buf = createDppReduxOp(valType, buf, DppCtrl::BCAST31, 0xf, 0xf, true);
buf = createDppReduxOpWithBoundCtrl(valType, buf,
static_cast<uint32_t>(DppCtrl::BCAST31),
allRows, allBanks);

// Similarly, we need to cast data types for readlane instruction.
Type actualType = castToInt(rewriter, loc, buf, valType, 16);
Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16);

// Get reduction result from lane 63
std::string intrinsic = "llvm.amdgcn.readlane";
Expand All @@ -276,7 +299,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
ValueRange{buf, i32_val(63)})
->getResult(0);

castFromInt(rewriter, loc, result, valType, 16);
result = truncAndCastFromInt(rewriter, loc, result, valType, 16);

acc[i] = result;
}
Expand Down
40 changes: 27 additions & 13 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using mlir::triton::AMD::DppCtrl;
using mlir::triton::AMD::ISAFamily;
using mlir::triton::gpu::appendOrGetExternFuncOp;
using mlir::triton::gpu::getFunctionType;

Expand Down Expand Up @@ -154,50 +156,62 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter,
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
}

auto createDppOp = [&](Value &old, Value &src, uint32_t dppCtrl,
uint32_t rowMask, uint32_t bankMask,
bool boundCtrl) {
auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src,
uint32_t dppCtrl, uint32_t rowMask,
uint32_t bankMask) {
return rewriter.create<ROCDL::DPPUpdateOp>(
loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl),
rewriter.getI32IntegerAttr(rowMask),
rewriter.getI32IntegerAttr(bankMask),
rewriter.getBoolAttr(boundCtrl));
rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false));
};

const int allRows = 0xf;
const int allBanks = 0xf;

switch (strideInt) {
case 1: {
// quad_perm: 1, 0, 3, 2
uint32_t dppCtrl = 0;
uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
std::array<uint32_t, 4> mask = {1, 0, 3, 2};
for (int i = 0; i < mask.size(); i++) {
dppCtrl |= mask[i] << (i * 2);
}
return createDppOp(val, val, dppCtrl, 0xf, 0xf, false);
return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
allBanks);
}
case 2: {
// quad_perm: 2, 3, 0, 1
uint32_t dppCtrl = 0;
uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
std::array<uint32_t, 4> mask = {2, 3, 0, 1};
for (int i = 0; i < mask.size(); i++) {
dppCtrl |= mask[i] << (i * 2);
}
return createDppOp(val, val, dppCtrl, 0xf, 0xf, false);
return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
allBanks);
}
case 4: {
// row_shr:4 bank_mask: 0xa
auto ret = createDppOp(val, val, 4 + DppCtrl::ROW_SHR0, 0xf, 0xa, false)
auto ret = createDppOpWithoutBoundCtrl(
val, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
allRows, 0xa)
.getRes();

// row_shl:4 bank_mask: 0x5
return createDppOp(ret, val, 4 + DppCtrl::ROW_SHL0, 0xf, 0x5, false);
return createDppOpWithoutBoundCtrl(
ret, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
0x5);
}
case 8: {
// row_shr:8 bank_mask: 0xc
auto ret = createDppOp(val, val, 8 + DppCtrl::ROW_SHR0, 0xf, 0xc, false)
auto ret = createDppOpWithoutBoundCtrl(
val, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
allRows, 0xc)
.getRes();

// row_shl:8 bank_mask: 0x3
return createDppOp(ret, val, 8 + DppCtrl::ROW_SHL0, 0xf, 0x3, false);
return createDppOpWithoutBoundCtrl(
ret, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
0x3);
}
default:
assert(false &&
Expand Down
15 changes: 8 additions & 7 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using mlir::triton::AMD::DppCtrl;
using mlir::triton::AMD::ISAFamily;

namespace mlir::LLVM::AMD {

const char predicatedLoad[] = "__predicated_load";
Expand All @@ -25,13 +22,17 @@ const char predicatedStoreCS[] = "__predicated_store_CS";
const char predicatedStoreWT[] = "__predicated_store_WT";

Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i,
ISAFamily isaFamily = ISAFamily::Unknown);
mlir::triton::AMD::ISAFamily isaFamily =
mlir::triton::AMD::ISAFamily::Unknown);
Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i,
ISAFamily isaFamily = ISAFamily::Unknown);
mlir::triton::AMD::ISAFamily isaFamily =
mlir::triton::AMD::ISAFamily::Unknown);
Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i,
ISAFamily isaFamily = ISAFamily::Unknown);
mlir::triton::AMD::ISAFamily isaFamily =
mlir::triton::AMD::ISAFamily::Unknown);
Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i,
ISAFamily isaFamily = ISAFamily::Unknown);
mlir::triton::AMD::ISAFamily isaFamily =
mlir::triton::AMD::ISAFamily::Unknown);

Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
int axis);
Expand Down
Loading

0 comments on commit 7ff9683

Please sign in to comment.