-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND] Add a loop unroller pass (#4645)
Adds a loop unroller pass that applies only to loops carrying an unroll annotation. An annotated loop looks like: ``` scf.for %arg5 = %c0_i32 to %arg3 step %c32_i32 : i32 { ... } {tt.loop_unroll_factor = 2 : i32} ```
- Loading branch information
Showing
7 changed files
with
126 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include <memory> | ||
|
||
#include "mlir/Dialect/SCF/Utils/Utils.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "triton/Analysis/Utility.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
#include "triton/Dialect/Triton/Transforms/Passes.h" | ||
#include "llvm/Support/Debug.h" | ||
|
||
#define GEN_PASS_CLASSES | ||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc" | ||
|
||
#define DEBUG_TYPE "triton-loop-unroll" | ||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") | ||
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") | ||
|
||
namespace mlir::triton { | ||
|
||
// Name of the integer attribute attached to an `scf.for` op to request
// unrolling by the given factor. `constexpr` so the pointer itself cannot be
// reseated at runtime.
static constexpr const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
|
||
namespace { | ||
|
||
/// Pass that unrolls every `scf.for` loop annotated with the
/// `tt.loop_unroll_factor` integer attribute by the requested factor. Loops
/// without the attribute, or with a factor <= 1, are left untouched.
class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {

  /// Returns the unroll factor requested on `forOp` through the
  /// `tt.loop_unroll_factor` attribute, or 1 (which suppresses unrolling) when
  /// no integer attribute with that name is attached to the loop.
  ///
  /// The value is returned as a 64-bit integer so that a large attribute value
  /// is not silently truncated into a small, valid-looking factor.
  static int64_t getUnrollFactorOrDefault(scf::ForOp forOp) {
    if (auto factor = forOp->getAttrOfType<IntegerAttr>(
            mlir::triton::loopUnrollFactorAttrName))
      return factor.getInt();
    return 1;
  }

public:
  LoopUnrollPass() = default;
  // The pass declares no data members of its own, so there is nothing to copy.
  LoopUnrollPass(const LoopUnrollPass &) {}

  /// Collects all annotated loops under the current operation, then unrolls
  /// each of them by its requested factor.
  void runOnOperation() override {
    LDBG("Loop unroll pass");

    // Gather the candidates first and transform afterwards, so the IR is not
    // mutated while it is being walked.
    SmallVector<scf::ForOp, 4> loops;
    getOperation()->walk([&](scf::ForOp forOp) {
      // Bail out for loops with unroll factor <= 1.
      if (getUnrollFactorOrDefault(forOp) > 1)
        loops.push_back(forOp);
    });

    for (auto loop : loops) {
      const int64_t unrollFactor = getUnrollFactorOrDefault(loop);
      // Drop the annotation before unrolling (presumably so that loops
      // produced by the unroller do not inherit it and the pass stays
      // idempotent -- NOTE(review): confirm against loopUnrollByFactor).
      loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
      LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
      // Unrolling is best-effort: a failure does not fail the pass, but it is
      // reported in the debug log instead of being silently discarded.
      if (failed(loopUnrollByFactor(loop,
                                    static_cast<uint64_t>(unrollFactor)))) {
        LDBG("Failed to unroll loop by factor " << unrollFactor);
      }
    }
  }
};
|
||
} // anonymous namespace | ||
|
||
std::unique_ptr<mlir::Pass> createLoopUnrollPass() { | ||
return std::make_unique<LoopUnrollPass>(); | ||
} | ||
|
||
} // namespace mlir::triton |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// RUN: triton-opt --split-input-file %s -triton-loop-unroll | FileCheck %s | ||
|
||
tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) { | ||
%c1_i32 = arith.constant 1 : i32 | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> | ||
%1 = tt.splat %cst : f32 -> tensor<256xf32> | ||
// Check the loop is unrolled by a factor of 2 and is followed by a remainder loop. | ||
// CHECK-LABEL: add_kernel_unroll | ||
// CHECK: scf.for | ||
// CHECK-COUNT-2: tt.load | ||
// CHECK-NOT: tt.load | ||
// CHECK: scf.for | ||
// CHECK: tt.load | ||
// CHECK-NOT: tt.load | ||
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 { | ||
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>> | ||
%4 = arith.addf %arg4, %3 : tensor<256xf32> | ||
%5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> | ||
scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr<f32>> | ||
} {tt.loop_unroll_factor = 2 : i32} | ||
tt.return | ||
} | ||
|
||
// ----- | ||
|
||
tt.func @add_kernel_nounroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) { | ||
%c1_i32 = arith.constant 1 : i32 | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> | ||
%1 = tt.splat %cst : f32 -> tensor<256xf32> | ||
// Check the loop is not unrolled, since it carries no tt.loop_unroll_factor attribute. | ||
// CHECK-LABEL: add_kernel_nounroll | ||
// CHECK: scf.for | ||
// CHECK-COUNT-1: tt.load | ||
// CHECK-NOT: tt.load | ||
// CHECK-NOT: scf.for | ||
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 { | ||
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>> | ||
%4 = arith.addf %arg4, %3 : tensor<256xf32> | ||
%5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> | ||
scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr<f32>> | ||
} | ||
tt.return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters