Skip to content

Commit

Permalink
[AMD] Support warp-level reduction with DPP (triton-lang#5019)
Browse files Browse the repository at this point in the history
This commit adds support for warp-level reduction
with DPP instructions, which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
  • Loading branch information
knwng authored Nov 14, 2024
1 parent f737843 commit 21119e3
Show file tree
Hide file tree
Showing 8 changed files with 404 additions and 40 deletions.
4 changes: 4 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,10 @@ def TT_ReduceOp: TT_Op<"reduce",
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();

// Returns the CombineOp iff this ReduceOp's region contains only
// one CombineOp other than the return, or nullptr if not applicable.
::mlir::Operation *getSingleCombiner();
}];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
return getElementTypesImpl(this->getOperands());
}

// Returns the combining op when this reduce's region consists of exactly one
// binary op feeding the terminator, consuming both block arguments in order;
// nullptr otherwise.
::mlir::Operation *ReduceOp::getSingleCombiner() {
  // Only the single-operand, single-result form is recognized.
  if (getNumOperands() != 1 || getNumResults() != 1)
    return nullptr;

  Block &body = *getCombineOp().begin();
  Operation *terminator = body.getTerminator();
  Operation *combiner = terminator->getOperand(0).getDefiningOp();

  // The combiner must be a binary op producing exactly one value.
  if (!combiner)
    return nullptr;
  if (combiner->getNumOperands() != 2 || combiner->getNumResults() != 1)
    return nullptr;

  // ...and it must consume the region's block arguments directly, in order.
  bool usesBlockArgsInOrder = combiner->getOperand(0) == body.getArgument(0) &&
                              combiner->getOperand(1) == body.getArgument(1);
  return usesBlockArgsInOrder ? combiner : nullptr;
}

unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }

//-- ScanOp --
Expand Down
76 changes: 76 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

// Full-wavefront (64-lane) f32 max reduction: expect the 6-step DPP sequence
// (row_shr:8/4/2/1 with 0x110+N encodings, then row_bcast:15 with row mask
// 0xa, then row_bcast:31), each step followed by the maxnum combiner, and a
// final readlane to broadcast lane 63's result.
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: reduce_dpp_max
  tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 280, 15, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK-NEXT: rocdl.update.dpp
    // CHECK-SAME: with 276, 15, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK-NEXT: rocdl.update.dpp
    // CHECK-SAME: with 274, 15, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK-NEXT: rocdl.update.dpp
    // CHECK-SAME: with 273, 15, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK-NEXT: rocdl.update.dpp
    // CHECK-SAME: with 322, 10, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK-NEXT: rocdl.update.dpp
    // CHECK-SAME: with 323, 15, 15, true : f32
    // CHECK-NEXT: llvm.intr.maxnum

    // CHECK: llvm.amdgcn.readlane
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32, %arg2: f32):
      %1 = arith.maxnumf %arg1, %arg2 : f32
      tt.reduce.return %1 : f32
    }) : (tensor<64xf32, #blocked3>) -> f32
    tt.return
  }
}

// -----

// Partial (32-lane) reduction in a 64-wide warp: expect ds_swizzle plus DPP
// steps emitted without bound_ctrl (trailing `false`) and on i32, rather than
// the full-wavefront row_shr/row_bcast sequence used for 64 lanes.
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: reduce_xor_max
  tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
    // CHECK: rocdl.ds_swizzle
    // CHECK: llvm.intr.maxnum

    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 280, 15, 12, false : i32
    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 264, 15, 3, false : i32
    // CHECK: llvm.intr.maxnum

    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 276, 15, 10, false : i32
    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 260, 15, 5, false : i32
    // CHECK: llvm.intr.maxnum

    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 78, 15, 15, false : i32
    // CHECK: llvm.intr.maxnum

    // CHECK: rocdl.update.dpp
    // CHECK-SAME: with 177, 15, 15, false : i32
    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
    ^bb0(%arg1: f32, %arg2: f32):
      %1 = arith.maxnumf %arg1, %arg2 : f32
      tt.reduce.return %1 : f32
    }) : (tensor<32xf32, #blocked4>) -> f32
    tt.return
  }
}
11 changes: 11 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ enum class ISAFamily {
// Deduces the corresponding ISA family for the given target gfx |arch|.
ISAFamily deduceISAFamily(llvm::StringRef arch);

// Here is a partial definition of DppCtrl enums. For the complete definition,
// please check:
// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
enum class DppCtrl : uint32_t {
  QUAD_PERM_FIRST = 0, // Start of the quad_perm encoding range.
  ROW_SHL0 = 0x100,    // row_shl:N is encoded as ROW_SHL0 + N.
  ROW_SHR0 = 0x110,    // row_shr:N is encoded as ROW_SHR0 + N.
  BCAST15 = 0x142,     // row_bcast:15 — broadcast lane 15 of each row.
  BCAST31 = 0x143      // row_bcast:31 — broadcast lane 31.
};

} // namespace mlir::triton::AMD

#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H
184 changes: 179 additions & 5 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using mlir::triton::AMD::DppCtrl;
namespace mlir::triton::AMD {

namespace {
Expand Down Expand Up @@ -103,34 +104,207 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,

// Butterfly (XOR-lane) shuffle of |val| by |i|. The ISA family is forwarded
// so the AMD lowering can pick target-specific instructions.
Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                             int i) const {
  // Fix: drop the stale pre-refactor call that lacked getISAFamily(); it made
  // the updated call below unreachable.
  return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily());
}

// Shuffle-up of |val| by |i| lanes, forwarding the ISA family to the AMD
// lowering.
Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val,
                            int i) const {
  // Fix: drop the stale pre-refactor call that lacked getISAFamily(); it made
  // the updated call below unreachable.
  return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily());
}

// Index shuffle: read |val| from the lane given by constant |i|, forwarding
// the ISA family to the AMD lowering.
Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                             int i) const {
  // Fix: drop the stale pre-refactor call that lacked getISAFamily(); it made
  // the updated call below unreachable.
  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
}

// Index shuffle with a dynamic lane index |i|, forwarding the ISA family to
// the AMD lowering.
Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                             Value i) const {
  // Fix: drop the stale pre-refactor call that lacked getISAFamily(); it made
  // the updated call below unreachable.
  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
}

// Returns the program id along |axis| via the AMD-specific helper.
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
                            ModuleOp moduleOp, int axis) const {
  return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);
}

// Cast and sext values into specific-length int to meet the requirements of
// instructions like UpdateDpp or readlane if necessary.
//
// |val| is rewritten in place: a non-integer value is first bitcast to an
// integer of the same width, then sign-extended to |toBits| if narrower.
// Returns the type |val| ends up with, so callers can build the consuming op.
// NOTE(review): bitcast/int_ty/sext are Triton conversion helpers that
// presumably pick up `rewriter`/`loc` from this scope — confirm in Utility.h.
static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc,
                                    Value &val, Type fromType,
                                    unsigned toBits) {
  unsigned originalBits = fromType.getIntOrFloatBitWidth();
  Type toType = fromType;

  // Reinterpret non-integer (e.g. float) values bit-for-bit as integers.
  if (!fromType.isIntOrIndex()) {
    val = bitcast(val, int_ty(originalBits));
    toType = int_ty(originalBits);
  }

  // Widen to the instruction's required operand width.
  if (originalBits < toBits) {
    val = sext(int_ty(toBits), val);
    toType = int_ty(toBits);
  }

  return toType;
}

// Trunc the value to specific length and then cast it to given type if
// necessary. This function is typically used in conjunction with
// castToAndSExtInt.
//
// Inverse of castToAndSExtInt: truncates |val| from |fromBits| back to the
// width of |valType|, then bitcasts to |valType| when it is non-integer.
static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
                                        Value val, Type valType,
                                        unsigned fromBits) {
  unsigned originalBits = valType.getIntOrFloatBitWidth();
  Value toVal = val;

  // Drop the high bits added by the earlier sign-extension.
  if (originalBits < fromBits) {
    toVal = trunc(int_ty(originalBits), toVal);
  }

  // Reinterpret back to the original non-integer type (e.g. f16).
  if (!valType.isIntOrIndex()) {
    toVal = bitcast(toVal, valType);
  }

  return toVal;
}

// Lowers a full-wavefront reduction to a DPP instruction sequence when
// possible. Returns true and replaces each value in |acc| with its reduced
// result on success; returns false to fall back to the generic shuffle path.
bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                            SmallVector<Value> &acc, triton::ReduceOp op,
                            unsigned numLaneToReduce,
                            unsigned interleave) const {
  // Fix: removed a stale `return false;` left at the top of the function,
  // which made the entire DPP lowering below unreachable.

  // The sequence below hardcodes a full 64-lane wavefront reduction.
  if (numLaneToReduce != 64)
    return false;

  // Only enabled for CDNA2/CDNA3; other families use the fallback path.
  if (auto family = getISAFamily();
      family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) {
    return false;
  }

  // We clone the combiner once per DPP step, so the reduce region must be a
  // single binary combining op.
  Operation *reduxOp = op.getSingleCombiner();
  if (!reduxOp)
    return false;

  // Emits one DPP move (with bound_ctrl set) of |src| and then applies a
  // clone of the combiner to (src, moved src), returning the combined value.
  auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src,
                                           uint32_t dppCtrl, int rowMask,
                                           int bankMask) -> Value {
    // DPP has limited support for data types, so here we need to
    // cast non-integer types or integer types shorter than 32 bits
    // to int32, except for fp32.
    Type actualType = valType;
    if (!valType.isF32()) {
      actualType = castToAndSExtInt(rewriter, loc, src, valType, 32);
    }

    Value dppResult =
        rewriter
            .create<ROCDL::DPPUpdateOp>(loc, actualType, src, src,
                                        rewriter.getI32IntegerAttr(dppCtrl),
                                        rewriter.getI32IntegerAttr(rowMask),
                                        rewriter.getI32IntegerAttr(bankMask),
                                        rewriter.getBoolAttr(true))
            .getRes();

    if (!valType.isF32()) {
      src = truncAndCastFromInt(rewriter, loc, src, valType, 32);
      dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32);
    }

    // Re-instantiate the combiner with (src, dppResult) as operands.
    IRMapping mapping;
    mapping.map(reduxOp->getOperand(0), src);
    mapping.map(reduxOp->getOperand(1), dppResult);
    return rewriter.clone(*reduxOp, mapping)->getResult(0);
  };

  for (int i = 0; i < acc.size(); i++) {
    Value buf;
    auto valType = acc[i].getType();

    /*
    Here's the implementation of full-wavefront reduction using dpp.
    https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    Each step has a v_mov_dpp instruction following the redux op. In
    some cases, the lower-level compiler could merge them into single
    instruction. For example, v_mov_dpp + max => v_max_dpp.
    For gfx9, we have 64 threads per warp. These 64 threads are arranged
    into 4 rows, with each row being 16 threads. Each 16 threads are arranged
    further into 4 banks, with each bank being 4 threads. Overall it's in a
    (row, bank, thread) structure. When shuffling, we use row/bank mask to
    indicate which row/bank to participate. Then modifier like row_shr and
    row_bcast means exact data movement schemes. In the following
    instructions, taking row 0 as an example:
    Step 1: Right shift for 8 lanes.
    lane 8-15 = redux(lane 0-7, lane 8-15)
    Step 2: Right shift for 4 lanes.
    lane 12-15 = redux(lane 8-11, lane 12-15)
    Step 3: Right shift for 2 lanes.
    lane 14-15 = redux(lane 12-13, lane 14-15)
    Step 4: Right shift for 1 lane.
    lane 15 = redux(lane 14, lane 15)
    Step 5: Broadcast lane 15 of each row to all the lanes of its next row.
    lane 16-31 = redux(lane 15, lane 16-31)
    Step 6: Broadcast lane 31 to lane 32-63.
    lane 32-63 = redux(lane 31, lane 32-63)
    Now the reduction result is stored in lane 63.
    Step 7: Read the reduction result from lane 63 and broadcast with
    readlane.
    */

    const int allRows = 0xf;
    const int allBanks = 0xf;

    const uint32_t dppCtrlRowShr = static_cast<uint32_t>(DppCtrl::ROW_SHR0);

    // row_shr:8
    buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr,
                                        allRows, allBanks);

    // row_shr:4
    buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr,
                                        allRows, allBanks);

    // row_shr:2
    buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr,
                                        allRows, allBanks);

    // row_shr:1
    buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
                                        allRows, allBanks);

    // row_bcast:15 row_mask:0xa
    buf = createDppReduxOpWithBoundCtrl(
        valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);

    // row_bcast:31
    buf = createDppReduxOpWithBoundCtrl(valType, buf,
                                        static_cast<uint32_t>(DppCtrl::BCAST31),
                                        allRows, allBanks);

    // Similarly, we need to cast data types for readlane instruction.
    Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16);

    // Get reduction result from lane 63
    std::string intrinsic = "llvm.amdgcn.readlane";
    Value result =
        LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType,
                                        ValueRange{buf, i32_val(63)})
            ->getResult(0);

    result = truncAndCastFromInt(rewriter, loc, result, valType, 16);

    acc[i] = result;
  }

  return true;
}

void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
Expand Down
Loading

0 comments on commit 21119e3

Please sign in to comment.