Skip to content

Commit

Permalink
[BACKEND] Add a loop unroller pass (#4645)
Browse files Browse the repository at this point in the history
Adds a loop unroller pass that applies only to loops carrying an unroll
annotation.

An annotated loop will look like:


```
    scf.for %arg5 = %c0_i32 to %arg3 step %c32_i32 : i32 {
      ...
    } {tt.loop_unroll_factor = 2 : i32}
```
  • Loading branch information
htyu authored Sep 9, 2024
1 parent e192dba commit 7df871d
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ std::unique_ptr<Pass> createCombineOpsPass();

std::unique_ptr<Pass> createReorderBroadcastPass();
std::unique_ptr<Pass> createRewriteTensorPointerPass();
std::unique_ptr<Pass> createLoopUnrollPass();

} // namespace triton

Expand Down
10 changes: 10 additions & 0 deletions include/triton/Dialect/Triton/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,14 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
let dependentDialects = ["mlir::triton::TritonDialect"];
}

// Declares the module-level "triton-loop-unroll" pass, constructed through
// mlir::triton::createLoopUnrollPass().
def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::ModuleOp"> {
  let summary = "Loop unroller";
  let description = [{
    The pass unrolls scf loops that carry the tt.loop_unroll_factor attribute. The attribute specifies the factor
    by which the loop should be unrolled.
  }];
  let constructor = "mlir::triton::createLoopUnrollPass()";
  let dependentDialects = ["mlir::triton::TritonDialect"];
}

#endif
1 change: 1 addition & 0 deletions lib/Dialect/Triton/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)

add_triton_library(TritonTransforms
Combine.cpp
LoopUnroll.cpp
ReorderBroadcast.cpp
RewriteTensorPointer.cpp

Expand Down
67 changes: 67 additions & 0 deletions lib/Dialect/Triton/Transforms/LoopUnroll.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <memory>

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"

// Tag that identifies this pass in debug output; it is also embedded in the
// prefix printed by DBGS().
#define DEBUG_TYPE "triton-loop-unroll"
// Debug stream with a "[triton-loop-unroll]: " prefix already written to it.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Prints X on one prefixed line, followed by a newline. The statement is
// wrapped in LLVM_DEBUG, so it is subject to LLVM's debug-output controls.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::triton {

// Name of the integer attribute that requests unrolling of an scf.for loop.
// Declared constexpr so the pointer itself cannot be re-pointed at runtime.
static constexpr const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";

namespace {

/// Module pass that unrolls every scf.for loop annotated with the
/// tt.loop_unroll_factor integer attribute by the requested factor. Loops
/// without the attribute, or with a factor <= 1, are left untouched.
class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {

  /// Returns the unroll factor attached to `forOp`, or 1 when the loop has no
  /// integer tt.loop_unroll_factor attribute (a factor of 1 suppresses
  /// unrolling). The full 64-bit attribute value is returned so that an
  /// oversized factor is not silently narrowed.
  int64_t getUnrollFactorOrDefault(scf::ForOp forOp) {
    if (auto factor = forOp->getAttrOfType<IntegerAttr>(
            mlir::triton::loopUnrollFactorAttrName))
      return factor.getInt();
    return 1;
  }

public:
  LoopUnrollPass() = default;
  // The pass declares no data members, so there is nothing to copy.
  LoopUnrollPass(const LoopUnrollPass &) {}

  /// Collects all annotated loops in the module, then unrolls each of them.
  void runOnOperation() override {
    LDBG("Loop unroll pass");
    // Gather the candidates first so the IR is not rewritten while it is
    // being walked.
    SmallVector<scf::ForOp, 4> loops;
    getOperation()->walk([&](scf::ForOp forOp) {
      // Bail out for loops with unroll factor <= 1.
      if (getUnrollFactorOrDefault(forOp) > 1)
        loops.push_back(forOp);
    });

    for (auto loop : loops) {
      const int64_t unrollFactor = getUnrollFactorOrDefault(loop);
      // Drop the annotation before unrolling; presumably this keeps any loop
      // derived from this one (e.g. a remainder loop) from inheriting it.
      loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
      LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
      // Unrolling stays best-effort, but a failure is reported in the debug
      // output instead of being discarded. The loop is not printed here
      // because its state after a failed attempt is not known at this point.
      if (failed(loopUnrollByFactor(loop, unrollFactor))) {
        LDBG("Failed to unroll loop by " << unrollFactor << " times");
      }
    }
  }
};

} // anonymous namespace

std::unique_ptr<mlir::Pass> createLoopUnrollPass() {
return std::make_unique<LoopUnrollPass>();
}

} // namespace mlir::triton
1 change: 1 addition & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void init_triton_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass);
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
createRewriteTensorPointerPass);
ADD_PASS_WRAPPER_0("add_loop_unroll", createLoopUnrollPass);
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
createConvertTritonToTritonGPUPass, const std::string &,
int, int, int);
Expand Down
45 changes: 45 additions & 0 deletions test/Triton/loop-unroll.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: triton-opt --split-input-file %s -triton-loop-unroll | FileCheck %s

// Positive case: the scf.for below carries tt.loop_unroll_factor = 2 and its
// body holds a single tt.load.
tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant 0.000000e+00 : f32
%0 = tt.splat %c1_i32 : i32 -> tensor<256xi32>
%1 = tt.splat %cst : f32 -> tensor<256xf32>
// Check the loop is unrolled by factor of 2 and is followed by a remainder loop.
// The first loop must hold exactly two loads and the second loop exactly one.
// CHECK-LABEL: add_kernel_unroll
// CHECK: scf.for
// CHECK-COUNT-2: tt.load
// CHECK-NOT: tt.load
// CHECK: scf.for
// CHECK: tt.load
// CHECK-NOT: tt.load
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
%4 = arith.addf %arg4, %3 : tensor<256xf32>
%5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>
} {tt.loop_unroll_factor = 2 : i32}
tt.return
}

// -----

// Negative case: the same kernel as the unroll test, but the scf.for below has
// no tt.loop_unroll_factor attribute.
tt.func @add_kernel_nounroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant 0.000000e+00 : f32
%0 = tt.splat %c1_i32 : i32 -> tensor<256xi32>
%1 = tt.splat %cst : f32 -> tensor<256xf32>
// Check the loop is not unrolled.
// Exactly one loop with exactly one load is expected in the output.
// CHECK-LABEL: add_kernel_nounroll
// CHECK: scf.for
// CHECK-COUNT-1: tt.load
// CHECK-NOT: tt.load
// CHECK-NOT: scf.for
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
%4 = arith.addf %arg4, %3 : tensor<256xf32>
%5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>
}
tt.return
}
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def make_ttir(mod, metadata, opt):
passes.common.add_cse(pm)
passes.common.add_licm(pm)
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
pm.run(mod)
return mod

Expand Down

0 comments on commit 7df871d

Please sign in to comment.