* Removed outer loop pipelining. It does not improve performance and may be
  replaced with loop fusion.

* Reorder will not move loads/local_stores over loops.
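
The second bullet corresponds to the new check in findEarlyInsertionPoint in the diff below: the scan for an insertion point now treats scf.for/scf.while ops as barriers, just as it already treated atomics, so loads and local_stores are no longer hoisted across a preceding loop. Below is a minimal standalone C++ sketch of that idea, using hypothetical types and names rather than the pass's actual MLIR API:

#include <vector>

// Hypothetical stand-in for an operation in a block, in program order.
struct OpInfo {
  bool isAtomic = false; // e.g. an atomic RMW/CAS op
  bool isLoop = false;   // e.g. an scf.for / scf.while op
};

// Sketch of the "earliest insertion point" scan: walk the ops preceding
// `pos` and push the candidate insertion point past anything that must not
// be crossed. With this commit, loops count as barriers in addition to
// atomics, so a load/local_store stays below any loop that precedes it.
static int findEarlyInsertionPointSketch(const std::vector<OpInfo> &block,
                                         int pos) {
  int loc = 0; // start of the block
  for (int i = 0; i < pos; ++i) {
    if (block[i].isAtomic || block[i].isLoop)
      loc = i + 1; // cannot move above this op
  }
  return loc;
}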
sjw36 committed Jul 16, 2024
1 parent 77f837a commit 9517277
Showing 3 changed files with 76 additions and 166 deletions.
38 changes: 19 additions & 19 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -350,31 +350,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: %{{.*}}:9 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}, %[[ARG15:.*]] = %{{.*}})

// CHECK: %[[SUBI_25:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]]
// CHECK: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_26]]
// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG8]], %{{.*}}
// CHECK: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_27]]
// CHECK: %[[ADDPTR_30:.*]] = tt.addptr %[[ARG9]], %{{.*}}
// CHECK: %[[LOAD_31:.*]] = tt.load %[[ADDPTR_30]], %[[CMPI_26]]
// CHECK: %[[MULI_32:.*]] = arith.muli %{{.*}}, %[[LOAD_31]]
// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[MULI_32]]
// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_26]]
// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %{{.*}}, %[[SPLAT_33]]
// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_35]], %[[SPLAT_34]]
// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG11]], %{{.*}}
// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_40]]
// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG15]], %[[MEMDESC_SUBVIEW_41]]
// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG11]], %{{.*}}
// CHECK: %[[CMPI_27:.*]] = arith.cmpi slt, %[[ADDI_26]], %{{.*}}
// CHECK: %[[SELECT_28:.*]] = arith.select %[[CMPI_27]], %[[ADDI_26]], %{{.*}}
// CHECK: %[[MEMDESC_SUBVIEW_29:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_28]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_29]]
// CHECK: %[[MEMDESC_SUBVIEW_30:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_28]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG15]], %[[MEMDESC_SUBVIEW_30]]
// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]]
// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]]
// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG8]], %{{.*}}
// CHECK: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_32]]
// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG9]], %{{.*}}
// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_35]], %[[CMPI_31]]
// CHECK: %[[MULI_37:.*]] = arith.muli %{{.*}}, %[[LOAD_36]]
// CHECK: %[[SPLAT_38:.*]] = tt.splat %[[MULI_37]]
// CHECK: %[[SPLAT_39:.*]] = tt.splat %[[CMPI_31]]
// CHECK: %[[ADDPTR_40:.*]] = tt.addptr %{{.*}}, %[[SPLAT_38]]
// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_40]], %[[SPLAT_39]]
// CHECK: %[[ADDI_42:.*]] = arith.addi %[[ARG10]], %{{.*}}
// CHECK: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}}
// CHECK: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}}
// CHECK: %[[LOCAL_LOAD_45:.*]] = triton_gpu.local_load %[[ARG12]]
// CHECK: %[[LOCAL_LOAD_46:.*]] = triton_gpu.local_load %[[ARG13]]
// CHECK: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_45]], %[[LOCAL_LOAD_46]], %[[ARG7]]
// CHECK: scf.yield %[[DOT_47]], %[[ADDPTR_28]], %[[ADDPTR_30]], %[[SELECT_44]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_29]], %[[LOAD_36]]
// CHECK: scf.yield %[[DOT_47]], %[[ADDPTR_33]], %[[ADDPTR_35]], %[[SELECT_44]], %[[SELECT_28]], %[[MEMDESC_SUBVIEW_29]], %[[MEMDESC_SUBVIEW_30]], %[[LOAD_34]], %[[LOAD_41]]
// CHECK: }

tt.func @indirect_bmm_scalar(%arg0: i64 {tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr<f16>, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> {
@@ -96,6 +96,8 @@ findEarlyInsertionPoint(Block *block, Operation *move, Value src) {
op->walk([&](Operation *wop) {
if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(wop))
loc = bi;
if (isa<scf::ForOp, scf::WhileOp>(wop))
loc = bi;
});
}
return loc;
@@ -152,8 +154,9 @@ class TritonAMDGPUReorderInstructionsPass
// Move local_stores early if the dependence distance is greater than
  // one iteration. Best perf on GEMM when these precede global loads.
m.walk([&](triton::gpu::LocalStoreOp op) { moveOps.push_back(op); });

for (auto op : moveOps) {
// 0. Gather use-def chain in block.
// Gather use-def chain in block.
Block *block = op->getBlock();
SmallVector<Operation *> dfg{op};
bool leadsToLoad = gatherDFG(op, block, dfg);
199 changes: 53 additions & 146 deletions third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -24,7 +24,7 @@
// Software pipeliners are usually separated into two pieces, one that creates a
// modulo schedule and an expander that rewrites the loop and emits a prologue
// and epilogue. This pass first calls a helper that will pre-process the IR
// to create async operations and create a modulo schedule. Then we call the
// to create stream operations and create a modulo schedule. Then we call the
// expander to generate the prologue and new loop.
//===----------------------------------------------------------------------===//

@@ -41,9 +41,6 @@ using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;

// TODO: We can extract some helpers into common utilities once we add more
// schedules.

namespace {

struct LoadInfo {
@@ -69,12 +66,12 @@ static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
yieldOp->erase();
}

static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx,
tt::CoarseSchedule &schedule,
tt::CoarseSchedule::Cluster prefetchCluster,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
static void createStreamCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx,
tt::CoarseSchedule &schedule,
tt::CoarseSchedule::Cluster prefetchCluster,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
OpBuilder builder(forOp);
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
// Replace the load with insert/extract slice.
@@ -140,8 +137,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
auto result = sharedLoad->getResults();

// Create a select for non-zero other values as they are not handled by
// AsyncCopyGlobalToLocalOp for now.
// Create a select for non-zero other values.
Value other = loadOp.getOther();
if (other && !isZeroConst(other)) {
auto select = builder.create<arith::SelectOp>(
@@ -235,7 +231,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) {
[&](Operation *op, int distance, Operation *use) {
if (!seen.insert(op).second)
return;
if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
if (isa<tt::LoadOp>(op)) {
// TODO: What if there are multiple uses at different distances?
loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use));
use = op;
Expand All @@ -261,7 +257,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) {
// that are not directly used by dot ops.
if (forOp->hasAttr(tt::kNumStagesAttrName)) {
for (Operation &op : forOp.getBody()->without_terminator()) {
if (!isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op))
if (!isa<tt::LoadOp>(op))
dfs(&op, 0, &op);
}
}
@@ -281,32 +277,28 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
continue;
LoadInfo loadInfo;

if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
assert(!isLoadFromTensorPtr(loadOp) &&
"Block ptr should have been lowered before this pass.");
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
if (auto mask = loadOp.getMask())
vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
auto loadOp = dyn_cast<tt::LoadOp>(op);
assert(!isLoadFromTensorPtr(loadOp) &&
"Block ptr should have been lowered before this pass.");
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
if (auto mask = loadOp.getMask())
vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
if (!tensorTy)
continue;
auto ty =
cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();

// We do not pipeline all loads for the following reasons:
// 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.
// 2. It's likely that pipelining small loads won't offer much performance
// improvement and may even hurt performance by increasing register
// pressure.
LDBG("Load " << *loadOp << " has width " << width);
if (width < 32)
continue;
}
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
if (!tensorTy)
continue;

auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
unsigned width = vec * ty.getIntOrFloatBitWidth();

// Limit shared memory sharing to loads with width >= 32 bits.
LDBG("Load " << *loadOp << " has width " << width);
if (width < 32)
continue;

if (use->hasTrait<OpTrait::DotLike>()) {
// Only use shared memory when feeding a dot op
loadInfo.usedByDot = true;
loadInfo.sharedEncoding =
getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
@@ -327,9 +319,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
// encoding.
if (!loadInfo.sharedEncoding) {
// Also pipeline in-register buffers.
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis);
}
loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis);
}

loadToInfo[op] = loadInfo;
@@ -412,66 +402,6 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
return loadToInfo;
}

// Schedule the prologue and epilogue `if` ops in the loop, pushing them as
// close to the loop boundaries as possible. Return the cluster after the
// prologue (or the beginning of the loop if there is no prologue).
static tt::CoarseSchedule::Cluster
schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule,
DenseSet<Operation *> &rootUsers, int numStages) {
tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin();

// Look for the IfOp that is in the backward slice of any of the currently
// scheduled ops and put it at the beginning of the loop.
DenseMap<scf::IfOp, int> ifsToStage;
// Go stage by stage.
for (int stage = 0; stage < numStages; stage++) {
for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) {
if (stage_ != stage)
continue;
SetVector<Operation *> backwardSlice;
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice((Operation *)op, &backwardSlice, opt);

for (auto op : backwardSlice) {
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
ifsToStage.insert({ifOp, stage});
}
}
}
}
tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront();
for (auto [ifOp, stage] : ifsToStage) {
schedule.insert(ifOp, stage, prologueCluster);
}

// Look for the IfOp that is in the forward slice of the root users and put it
// at the end of the loop.
tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack();
for (auto rootUser : rootUsers) {
SetVector<Operation *> forwardSlice;
getForwardSlice(rootUser, &forwardSlice);

int stage = schedule[rootUser].first;
for (auto op : forwardSlice) {
scf::IfOp ifOp = dyn_cast<scf::IfOp>(op);
if (ifOp == nullptr) {
// check if the op is in the body of an if op that's part of the loop
auto parentOp = op->getParentOp();
if (parentOp != nullptr &&
parentOp->getParentOp() == forOp.getOperation()) {
ifOp = dyn_cast<scf::IfOp>(parentOp);
}
}
if (ifOp) {
schedule.insertIfAbsent(ifOp, stage,
epilogueCluster); // after prefetch extracts
}
}
}
return afterPrologue;
}

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
// the same stage and ordering cluster as the anchor op.
static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
@@ -600,9 +530,9 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
// Convert load ops into their async version and apply multi-buffering based on
// the required number of buffers.
static SmallVector<Value>
createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
// Calculate the number of buffers needed for each load.
// TODO pawel: we could do more fine-grained allocation here and
// allocate only the number of buffers that specific loads need.
@@ -677,8 +607,8 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,

for (auto &pair : asyncLoads) {
if (auto loadOp = dyn_cast<tt::LoadOp>(pair.first)) {
createAsyncCopy(forOp, loadOp, pair.second, insertIdx, extractIdx,
schedule, prefetchCluster, loadToInfo, numStages);
createStreamCopy(forOp, loadOp, pair.second, insertIdx, extractIdx,
schedule, prefetchCluster, loadToInfo, numStages);
}
}
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
@@ -709,19 +639,14 @@ preProcessLoopAndGetSchedule2(scf::ForOp &forOp, int numStages,

// Convert the loads into async loads and create the allocs.
SmallVector<Value> allocs =
createAsyncOps(forOp, coarseSchedule, loadToInfo, numStages);
createStreamOps(forOp, coarseSchedule, loadToInfo, numStages);

LLVM_DEBUG({
LDBG("Coarse schedule with async loads:");
LDBG("Coarse schedule with stream loads:");
coarseSchedule.dump();
});

tt::CoarseSchedule::Cluster afterPrologue =
schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages);
LLVM_DEBUG({
LDBG("Coarse schedule with prologue and epilogue:");
coarseSchedule.dump();
});
tt::CoarseSchedule::Cluster afterPrologue = coarseSchedule.clusters.begin();

scheduleDependencies(forOp, coarseSchedule, numStages);
LLVM_DEBUG({
@@ -768,7 +693,7 @@ preProcessLoopAndGetSchedule2(scf::ForOp &forOp, int numStages,
}

// Return true if the preconditions for pipelining the loop are met.
static bool preCondition(scf::ForOp forOp) {
static bool preConditionInner(scf::ForOp forOp) {
// Skip loop with distance > 1 for now.
// TODO: relax the constraint in the expander.
if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
@@ -791,22 +716,9 @@ static bool preCondition(scf::ForOp forOp) {
return true;
}

static void tryAndPipelineOuterLoop(scf::ForOp forOp) {
mlir::triton::PipeliningOption options;
bool foundSchedule = false;
// Limit 2 stages to not require extra shared memory.
foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options);
if (!foundSchedule)
return;
IRRewriter rewriter(forOp->getContext());
rewriter.setInsertionPoint(forOp);
FailureOr<scf::ForOp> newForOp =
mlir::triton::pipelineForLoop(rewriter, forOp, options);
}

static bool pipelineLoop(scf::ForOp forOp, int numStages) {
mlir::triton::PipeliningOption options;
if (!preCondition(forOp))
if (!preConditionInner(forOp))
return false;

bool foundSchedule = false;
@@ -851,29 +763,24 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
if (loops.empty())
return;

llvm::SmallSetVector<scf::ForOp, 8> outerLoops;
bool pipelined = false;
for (scf::ForOp forOp : loops) {
auto outerLoop = dyn_cast<scf::ForOp>(forOp->getParentOp());
int loopNumStages = getNumStagesOrDefault(forOp);
bool pipelined = pipelineLoop(forOp, loopNumStages);
if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1)
outerLoops.insert(outerLoop);
pipelined |= pipelineLoop(forOp, loopNumStages);
}

// Clean up arithmetic before applying the next level of pipelining to
// simplify the IR.
auto arithDialect =
getOperation().getContext()->getLoadedDialect<arith::ArithDialect>();
RewritePatternSet patterns(getOperation().getContext());
arithDialect->getCanonicalizationPatterns(patterns);
if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))
.failed())
return signalPassFailure();

// Try to pipeline the outer loop to overlap the prologue and epilogue of
// the inner loop.
for (scf::ForOp outerLoop : outerLoops)
tryAndPipelineOuterLoop(outerLoop);
if (pipelined) {
// Clean up arithmetic before applying the next level of pipelining to
// simplify the IR.
auto arithDialect =
getOperation().getContext()->getLoadedDialect<arith::ArithDialect>();
RewritePatternSet patterns(getOperation().getContext());
arithDialect->getCanonicalizationPatterns(patterns);
if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))
.failed())
signalPassFailure();
}
}
};
} // anonymous namespace
