Skip to content

Commit

Permalink
[PIPELINER] Cleanup of LoopScheduling.cpp, introduction of AssignLate…
Browse files Browse the repository at this point in the history
…ncies (#5176)

This change breaks down LoopScheduling into two sub-passes: latency
assignment and actual scheduling.
Latency assignment is a transformation that analyzes the loop and based
on the requested number of stages it assigns "latencies" to the ops that
are going to be converted to async ops by the pipeliner. Latencies are
expressed in terms of the number of loop iterations and can be
thought of as per-operation num_stages.
The scheduling transformation takes these latencies and builds a pipeliner
schedule based on them. The process of building a schedule was slightly
rewritten to simplify the code and cleanup the logic that was no longer
needed after recent refactoring.
Breaking down the scheduling into latency assignment and proper scheduling
serves several purposes:
1. The code became more modular, with cleaner interfaces that help with
maintenance
2. Both parts can be tested in isolation; I have added lit tests for
both pieces. We can finally test our pipeliner infrastructure in
manageable chunks
3. It opens up opportunity to expose per-op "latencies" to the frontend,
enabling creating user-defined schedules right from the language level

Next step in the cleanup process is to clearly separate lowering and
pipelining phases.
  • Loading branch information
pawelszczerbuk authored Dec 4, 2024
1 parent 712ac66 commit 00cc5d0
Show file tree
Hide file tree
Showing 16 changed files with 1,327 additions and 391 deletions.
32 changes: 32 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,38 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}

// Test-only pass: exposes the latency-assignment half of the pipeliner
// so it can be exercised directly from lit tests.
def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> {
  let summary = "test assigning latencies to interesting ops ahead of pipelining";

  let description = [{
    This is a test pass that tests `assignLatencies` method of `TritonGPULoopScheduling`.
  }];

  // Mirrors the pipeline pass's num_stages knob.
  let options = [
    Option<"numStages", "num-stages",
           "int32_t", /*default*/"3",
           "number of pipeline stages">
  ];

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
    "mlir::scf::SCFDialect",
    "mlir::arith::ArithDialect"
  ];
}

// Test-only pass: exposes the schedule-construction half of the pipeliner
// so it can be exercised directly from lit tests.
def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> {
  let summary = "test scheduling a loop for software pipelining";

  let description = [{
    This is a test pass that tests `scheduleLoop` method of `TritonGPULoopScheduling`.
  }];

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
    "mlir::scf::SCFDialect",
    "mlir::arith::ArithDialect"
  ];
}

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ static const char *kNumStagesAttrName = "tt.num_stages";
static const char *kLoopStageAttrName = "loop.stage";
static const char *kLoopClusterAttrName = "loop.cluster";

/// Returns true for loops with a loop-carried dependency of distance > 1 —
/// presumably meaning a value defined in iteration i is used in iteration
/// i+2 or later. NOTE(review): definition not visible here; confirm against
/// the implementation.
bool loopHasDistGreaterThanOne(scf::ForOp forOp);
/// Returns true if `forOp` is an outer loop (contains a nested loop) —
/// assumption based on the name; TODO confirm against the implementation.
bool isOuterLoop(scf::ForOp forOp);

/// Function to mask operations during scheduling.
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);

Expand Down
15 changes: 13 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
namespace mlir {
namespace triton {

namespace gpu {

/// Discover operations that should become async and assign latencies to them
/// based on the numStages value provided by the user.
/// Latencies are expressed in numbers of loop iterations and act as
/// per-operation num_stages for the subsequent scheduling step.
/// Renamed parameter `forOp` -> `moduleOp`: the argument is a ModuleOp, not a
/// loop, and the old name was misleading.
DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp, int numStages);

/// Schedule the loop based on the latencies assigned to the operations.
void scheduleLoop(scf::ForOp forOp,
                  const DenseMap<Operation *, int> &opLatency);

} // namespace gpu

/// This fill out the pipelining options including schedule and annotations
/// for wait ops. This also does pre-processing by converting some of the
/// loads into async loads so that the IR is ready to be pipelined.
Expand Down Expand Up @@ -108,8 +120,7 @@ class CoarseSchedule {

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
// the same stage and ordering cluster as the anchor op.
void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule,
int numStages);
void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule);

} // namespace triton
} // namespace mlir
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ add_triton_library(TritonGPUTransforms
OptimizeAccumulatorInit.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
# Pipeliner sources kept alphabetized: the new TestPipeline*.cpp entries had
# been inserted out of order (before SoftwarePipeliner.cpp).
Pipeliner/AssignLatencies.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/OuterLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/PipeliningUtility.cpp
Pipeliner/SoftwarePipeliner.cpp
Pipeliner/TMAStoresPipeline.cpp
Pipeliner/TestPipelineAssignLatencies.cpp
Pipeliner/TestPipelineScheduleLoop.cpp
Expand Down
Loading

0 comments on commit 00cc5d0

Please sign in to comment.