Skip to content

Commit

Permalink
[NFC] Remove unused forOp argument from setStageCluster (#5288)
Browse files Browse the repository at this point in the history
<git-pr-chain>


[NFC] Remove unused forOp argument from `setStageCluster`


#### [PR chain](https://github.com/jlebar/git-pr-chain)
1. 👉 #5288 👈 **YOU ARE HERE**
1. #5289
1. #5290


</git-pr-chain>
  • Loading branch information
peterbell10 authored Nov 30, 2024
1 parent 27e11ab commit 912e595
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
// Return the minClusterId and maxClusterId for the given ForOp.
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
std::pair<int, int> getStageCluster(Operation *op);
void setStageCluster(scf::ForOp &forOp, Operation *op, int stage, int cluster);
void setStageCluster(Operation *op, int stage, int cluster);
} // namespace triton
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@ class OpBuilderWithStage : public OpBuilder {
OpTy createWithStage(Location location, int stage, int cluster,
Args &&...args) {
OpTy op = OpBuilder::create<OpTy>(location, std::forward<Args>(args)...);
auto ctx = getContext();
op->setAttr(mlir::triton::kLoopStageAttrName,
IntegerAttr::get(IntegerType::get(ctx, 32), stage));
op->setAttr(mlir::triton::kLoopClusterAttrName,
IntegerAttr::get(IntegerType::get(ctx, 32), cluster));
tt::setStageCluster(op, stage, cluster);
return op;
}
using OpBuilder::create;
Expand Down Expand Up @@ -204,9 +200,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
// Prefetch load if is not MMAV3 and is used by the dot.
if (loadToInfo[loadOp].usedByDot) {
assert(stageForFirstUse >= 1);
tt::setStageCluster(forOp, wait, stageForFirstUse - 1, maxClusterId + 1);
tt::setStageCluster(forOp, viewLoad, stageForFirstUse - 1,
maxClusterId + 1);
tt::setStageCluster(wait, stageForFirstUse - 1, maxClusterId + 1);
tt::setStageCluster(viewLoad, stageForFirstUse - 1, maxClusterId + 1);
retCode = stageForFirstUse - 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
return std::make_pair(stage, clusterId);
}

void mlir::triton::setStageCluster(scf::ForOp &forOp, Operation *op, int stage,
int cluster) {
auto ctx = forOp.getContext();
void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) {
auto ctx = op->getContext();
op->setAttr(mlir::triton::kLoopStageAttrName,
IntegerAttr::get(IntegerType::get(ctx, 32), stage));
op->setAttr(mlir::triton::kLoopClusterAttrName,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void tt::CoarseSchedule::dump() {
// Set <stage, cluster> based on CoarseSchedule.
void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
tt::setStageCluster(forOp, op, stage, *cluster);
tt::setStageCluster(op, stage, *cluster);
}
}

Expand Down

0 comments on commit 912e595

Please sign in to comment.