Skip to content

Commit

Permalink
[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering …
Browse files Browse the repository at this point in the history
…for SMEM-to-MMAv3 DotOp Copy (triton-lang#5003)

Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS"
(LHS operand A in registers).
In cases where we apply elementwise operations on A before WGMMA, Triton
previously copied A from global memory (GMEM) into registers (RF),
performed the elementwise ops, and then copied the result to shared memory
(SMEM) to perform an SS WGMMA.

This PR adds an optimization for the case above to use RS GEMM. This
requires the following changes:

- In TritonGPU OptimizeDotOperands pass, add optimizations to change SS
GEMM into RS GEMM.
- Add TritonGPU -> LLVM lowering for copying from SMEM to RF in MMA v3
dotOperand layout.

NOTE: This may not yield a perf gain, and may even cause a perf loss, for
certain shapes (e.g. small-K); additional optimizations are in a
separate [PR](openxla#19) (still more
optimizations are WIP). Please advise on the merging strategy.
  • Loading branch information
ggengnv authored and hmalgewatta committed Nov 15, 2024
1 parent bd483c5 commit 6b092ae
Show file tree
Hide file tree
Showing 9 changed files with 512 additions and 69 deletions.
8 changes: 7 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ int getNVIDIAComputeCapability(Operation *module);
std::optional<mlir::triton::gpu::SharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

bool loadIsMMAv3(Operation *loadOp);
// Classification returned by getMMALoadType() for a load that feeds an
// MMA/dot operand.
// NOTE(review): the meaning of each enumerator below is inferred from its
// name and from how the loop scheduler consumes it -- confirm against the
// implementation of getMMALoadType().
enum class MMALoadType {
// Presumably: the loaded value is consumed as a shared-memory MMAv3 operand.
SharedV3,
Registers, // may be v2 or v3
DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
// pipelining
};
// Returns the MMALoadType classification for `loadOp`.
MMALoadType getMMALoadType(Operation *loadOp);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
11 changes: 9 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,18 @@ filterPipelinedLoad(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>

bool hasSharedEncoding = false;
if (use->hasTrait<OpTrait::DotLike>()) {
if (loadIsMMAv3(op)) {
auto mmaLoadType = getMMALoadType(op);
auto dot = dyn_cast<tt::DotOp>(use);
auto warpGroupDot = dyn_cast<ttng::WarpGroupDotOp>(use);
bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3;
bool isMMAv3Registers =
(mmaLoadType == MMALoadType::Registers) && warpGroupDot;

if (isMMAv3Shared) {
hasSharedEncoding = true;
} else if (isa<tt::ExperimentalDescriptorLoadOp>(op)) {
hasSharedEncoding = true;
} else if (auto dot = dyn_cast<tt::DotOp>(use)) {
} else if (isMMAv3Registers || dot) {
// FIXME: if we have a better solution in handling incompatible shared
// encoding, we can simplify the logic here by checking if all users are
// dot encoding. For now, getSharedEncIfAllUsersAreDotEnc will be used
Expand Down
287 changes: 270 additions & 17 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
Expand All @@ -15,6 +17,125 @@ namespace gpu {

namespace {

// Helpers

// Decides whether a DotOperand layout conversion may be hoisted above `op`.
//
// The rough criterion is that `op` is elementwise, so no thread has to
// exchange elements with another one. Several elementwise ops are still
// rejected because later stages do not support them yet; each rejection is
// explained at the corresponding check below.
//
// `dotOpEnc` is the DotOperand encoding the conversion targets; only its
// kWidth is inspected.
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
  // Candidates are limited to custom conversions and arith-dialect ops.
  // TODO(jlebar): Is this too restrictive?
  const bool isCustomConversion =
      isa<FpToFpOp, BitcastOp>(op) || isPureUnaryInlineAsm(op);
  if (!isCustomConversion && !isa<arith::ArithDialect>(op->getDialect()))
    return false;

  // Quick fix for a loading issue: when the original-bitwidth computation
  // fails to notice a mixed-precision dot (hence kWidth == 1), hoisting
  // through the type conversion must be avoided.
  const bool isExtWithUnitKWidth =
      isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1;

  // The shared -> dot_operand lowering does not support these instructions
  // yet; not all types and type conversions are handled there.
  const bool isUnsupportedByLowering =
      isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op);

  if (isExtWithUnitKWidth || isUnsupportedByLowering)
    return false;

  // u1 -> fp casts are not supported in
  // ElementwiseOpToLLVM::reorderValues(), so never hoist through them.
  const bool isBoolToFpCast =
      isa<arith::UIToFPOp>(op) &&
      getElementTypeOrSelf(op->getOperand(0)).isInteger(1);

  return !isBoolToFpCast;
}

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
//
// Returns true only if a DotOperand encoding can be propagated through `op`,
// i.e. `op`:
//   - has at least one operand and exactly one result,
//   - only touches ranked tensors with Blocked or DotOperand encoding,
//   - is a custom conversion (FpToFp, Bitcast, pure unary inline asm) or an
//     arith-dialect op other than arith.select,
//   - works on integer/float element types and does not narrow them,
//   - is not a u1 -> fp cast.
bool canHoistDotOpEncV3(Operation *op) {
  // Must have exactly one result and at least one operand
  if (op->getNumOperands() == 0 || op->getNumResults() != 1)
    return false;

  auto isBlockedOrDotOpRankedTensor = [](Type ty) {
    auto tensorTy = dyn_cast<RankedTensorType>(ty);
    if (!tensorTy)
      return false;
    return isa<BlockedEncodingAttr, DotOperandEncodingAttr>(
        tensorTy.getEncoding());
  };

  // Operands and results must be of RankedTensorType and Blocked or DotOp
  if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) &&
        all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor)))
    return false;

  // Only consider custom conversions or arith ops.
  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
      !isa<arith::ArithDialect>(op->getDialect()))
    return false;

  // Currently, these instructions are not supported during lowering of
  // shared -> dot_operand layout. Not all types and type conversions are
  // supported.
  if (isa<arith::SelectOp>(op))
    return false;

  Type oprType = getElementTypeOrSelf(op->getOperand(0));
  Type resType = getElementTypeOrSelf(op->getResult(0));

  // getIntOrFloatBitWidth() asserts on element types that are neither
  // integer nor float (e.g. index or pointer elements). Decline to hoist
  // through such ops instead of aborting the compiler.
  if (!oprType.isIntOrFloat() || !resType.isIntOrFloat())
    return false;

  // Downcasting not currently supported; it will likely require minor
  // adjustments in sharedToDotOperandMMv2
  if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth())
    return false;

  // Don't hoist through u1 -> fp casts as they aren't supported in
  // ElementwiseOpToLLVM::reorderValues().
  if (isa<arith::UIToFPOp>(op) && oprType.isInteger(1))
    return false;

  return true;
}

// Helper to perform a "deep" clone of the given slice (i.e., set of ops),
// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice,
// and sliceMap the IRMapping that maps the ops and result values of the
// original slice to those in the cloned slice.
//
// Each clone is inserted immediately before the op it was cloned from.
// Operands defined outside `slice` (including block arguments) are left
// untouched. The rewriter's insertion point is restored before returning, and
// every mutation goes through `rewriter` so that the pattern driver is
// notified of it.
auto cloneSlice(PatternRewriter &rewriter,
                const SetVector<Operation *> &slice) {
  IRMapping sliceMap;
  SetVector<Operation *> newSlice;

  // Do not leak the insertion-point changes made below to the caller.
  OpBuilder::InsertionGuard guard(rewriter);

  // First pass: clone ops; the result values are cloned as well, but the
  // operands still refer to the original result values
  for (Operation *op : slice) {
    rewriter.setInsertionPoint(op);
    Operation *newOp = rewriter.clone(*op);
    newSlice.insert(newOp);
    sliceMap.map(op, newOp);
    for (auto [result, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      assert(result != newResult);
      sliceMap.map(result, newResult);
    }
  }

  // Second pass: replace operand references in cloned ops to point to cloned
  // values
  for (Operation *newOp : newSlice) {
    for (OpOperand &use : newOp->getOpOperands()) {
      // Block arguments have no defining op and are never part of the slice.
      Operation *defOp = use.get().getDefiningOp();
      if (!defOp || !slice.contains(defOp))
        continue;

      rewriter.modifyOpInPlace(
          newOp, [&]() { use.set(sliceMap.lookup(use.get())); });
    }
  }

  return std::make_tuple(newSlice, sliceMap);
}

// Given
// convert(trans(src)) #dot_operand ->
// convert(local_load(trans(alloc(src))))
Expand Down Expand Up @@ -111,7 +232,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
Expand All @@ -126,16 +248,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
src->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>())
return failure();

// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
// supported.
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
if (!canHoistDotOpEncV2(src, dotOpEnc))
return failure();

// Check that the conversion is transitively dependent on a load, and all
Expand Down Expand Up @@ -165,12 +278,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<LoadOp>(currOp)) {
foundLoad = true;
} else if (foundLoad) {
// Bail out if there exists an op after Load that is not FpToFp,
// Bitcast, or Arith.
if (!isa<FpToFpOp, BitcastOp>(currOp) &&
!isPureUnaryInlineAsm(currOp) &&
currOp->getDialect()->getTypeID() !=
TypeID::get<arith::ArithDialect>())
if (!canHoistDotOpEncV2(currOp, dotOpEnc))
return failure();
}
}
Expand Down Expand Up @@ -301,6 +409,150 @@ struct MMAV3UseRegOperand
}
};

// MMAV3's analog of HoistLayoutConversion, for operand A only; will make
// WarpGroupDot accept operand A in registers instead of shmem.
//
// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
// After:  load #blocked; convert_layout #dot_op; (elementwise #dot_op)+;
//         warp_group_dot
//
// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a
// time and requires multiple passes, this pattern will directly hoist the
// convert to the right place in one pass.
//
// Or, to be more precise, this pattern deletes the local_alloc op and inserts a
// convert_layout op after each load that warp_group_dot uses; so this is not
// simply hoisting a convert_layout op up as in V2, but can be considered as
// first changing local_alloc to convert_layout and then hoisting, which results
// in WGMMA now accepting operand A in DotOp layout rather than Shared.
//
// The pattern fails to match (without touching the IR) unless every op
// between the frontier and the local_alloc satisfies canHoistDotOpEncV3 and
// every value entering that slice from outside is produced by a load or a
// constant. In particular, values entering through block arguments are
// rejected.
struct MMAV3HoistLayoutConversion
    : public OpRewritePattern<triton::nvidia_gpu::WarpGroupDotOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp,
                                PatternRewriter &rewriter) const override {
    // Can only hoist operand 0
    auto alloc = dotOp.getOperand(0).getDefiningOp<LocalAllocOp>();
    if (!alloc || !alloc.getSrc())
      return rewriter.notifyMatchFailure(
          dotOp, "operand A must be produced by local_alloc");

    auto getEncoding = [](Value v) {
      return cast<TensorOrMemDesc>(v.getType()).getEncoding();
    };

    if (!isa<SharedEncodingAttr>(getEncoding(dotOp.getOperand(0))))
      return rewriter.notifyMatchFailure(
          dotOp, "requires Shared encoding for operand A");

    // Step 1: Performs checks for early stop
    auto srcEnc = dyn_cast<BlockedEncodingAttr>(getEncoding(alloc.getSrc()));
    if (!srcEnc)
      return rewriter.notifyMatchFailure(
          alloc, "requires src to have Blocked encoding");

    auto dstEnc =
        dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
    if (!dstEnc || dstEnc.getVersionMajor() != 3)
      return rewriter.notifyMatchFailure(
          dotOp, "requires result in NvidiaMma encoding");

    // Step 2: Obtain slice of ops between load/constant and local_alloc
    SetVector<Operation *> slice;
    BackwardSliceOptions opt;
    opt.omitBlockArguments = true;
    opt.filter = [&](Operation *op) {
      // Stop before Load, ConstantOp, or LocalLoad
      return (op->getParentRegion() == alloc->getParentRegion()) &&
             !isa<LoadOp, arith::ConstantOp, LocalLoadOp>(op) &&
             (op->getNumOperands() != 0);
    };
    getBackwardSlice(alloc.getOperation(), &slice, opt);

    // Step 3: Verify slice can be hoisted through
    if (slice.empty())
      return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through");

    // We define frontierOp as an op outside this slice whose result is used by
    // an op in this slice. We must eventually convert the result of all
    // frontierOps to DotOperandEncoding. This is done via the insertion of
    // ConvertLayout after each frontierOp. We currently support frontierOp to
    // be load or constant.
    for (Operation *currOp : slice) {
      if (!canHoistDotOpEncV3(currOp))
        return rewriter.notifyMatchFailure(currOp, "cannot hoist through");

      // We previously ensured that all ops in slice have at least one operand
      for (Value operand : currOp->getOperands()) {
        Operation *defOp = operand.getDefiningOp();
        if (slice.contains(defOp))
          continue;

        // Block arguments are omitted from the backward slice and have no
        // defining op; they are neither load nor constant, so bail out here
        // rather than calling isa<> on a null pointer.
        if (!defOp)
          return rewriter.notifyMatchFailure(
              currOp, "operand is a block argument; must be load or constant");

        // ensure frontierOp is load or constant
        if (!isa<LoadOp, arith::ConstantOp>(defOp))
          return rewriter.notifyMatchFailure(defOp,
                                             "must be load or constant");
      }
    }

    // Step 4: Clone slice
    auto [newSlice, sliceMap] = cloneSlice(rewriter, slice);

    // Step 5: Modify the cloned slice to have dotOp encoding.
    // Before: load #blocked; (elementwise #blocked)+; local_alloc;
    //         warp_group_dot
    // After:  load #blocked; convert_layout #dot_op;
    //         (elementwise #dot_op)+; warp_group_dot
    //
    // Specifically, this step will change all value types from #blocked to
    // #dot_op encoding in the cloned slice, and for those values produced by
    // frontierOps (i.e., outside the slice), we will insert convert_layout's
    // after the frontierOp.
    auto srcTy = cast<RankedTensorType>(alloc.getSrc().getType());
    Type inputEltTy = srcTy.getElementType();
    auto dotOperandEnc = DotOperandEncodingAttr::get(
        dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy);

    for (Operation *op : newSlice) {
      // Step 5a: If any operand is defined by a frontierOp, we must insert a
      // convert_layout(#dot_op) after the frontierOp and before currOp
      for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) {
        Operation *defOp = operand.getDefiningOp();

        // defOp is not frontier (i.e. it's within slice); no need to convert
        // the layout of its result
        if (newSlice.contains(defOp))
          continue;

        // We checked earlier that all operands are ranked tensors
        auto operandTy = cast<RankedTensorType>(operand.getType());

        Type cvtTy = RankedTensorType::get(
            operandTy.getShape(), operandTy.getElementType(), dotOperandEnc);
        rewriter.setInsertionPoint(op);
        // Step 3 guarantees the frontier value is an op result, so its
        // location is the location of the frontierOp.
        auto cvt =
            rewriter.create<ConvertLayoutOp>(operand.getLoc(), cvtTy, operand);

        // Go through the rewriter so the pattern driver sees the mutation.
        const unsigned idx = static_cast<unsigned>(oprIdx);
        rewriter.modifyOpInPlace(op, [&]() { op->setOperand(idx, cvt); });
      }

      // Step 5b: Change the result to have DotOp rather than Blocked encoding
      auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
      rewriter.modifyOpInPlace(op, [&]() {
        op->getResult(0).setType(RankedTensorType::get(
            resTy.getShape(), resTy.getElementType(), dotOperandEnc));
      });
    }

    // Step 6: replace LHS operand with alloc's parent in the cloned slice
    // This changes the warpGroupDot to accept a DotOp tensor as operand A
    // instead of a Shared memdesc.
    auto newDotOperand = sliceMap.lookup(alloc.getSrc());
    rewriter.modifyOpInPlace(dotOp,
                             [&]() { dotOp.setOperand(0, newDotOperand); });

    return success();
  }
};

} // namespace

#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS
Expand All @@ -322,6 +574,7 @@ class TritonGPUOptimizeDotOperandsPass
auto ret = pm.run(m);

mlir::RewritePatternSet patterns(context);
patterns.add<MMAV3HoistLayoutConversion>(context);
patterns.add<SwizzleShmemConvert>(context);
if (this->hoistLayoutConversion.getValue())
patterns.add<HoistLayoutConversion>(context);
Expand Down
Loading

0 comments on commit 6b092ae

Please sign in to comment.