From 4210274664ca70fd0ae88e05b89b5618891b0dda Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Tue, 26 Nov 2024 02:34:51 +0800 Subject: [PATCH] [AMD] NFC: Drop v2 Suffix from Stream Pipeline (#5251) Since StreamPipelineV2 has been the default for a while, this commit promotes StreamPipelineV2 to the general StreamPipeline by removing the 'v2' suffix. --- bin/RegisterTritonDialects.h | 2 +- test/TritonGPU/amd/amd-instruction-sched.mlir | 10 +++++----- test/TritonGPU/loop-pipeline-hip.mlir | 2 +- test/TritonGPU/loop-pipeline.mlir | 4 ++-- third_party/amd/backend/compiler.py | 2 +- .../amd/include/TritonAMDGPUTransforms/Passes.h | 4 ++-- .../amd/include/TritonAMDGPUTransforms/Passes.td | 4 ++-- .../amd/lib/TritonAMDGPUTransforms/CMakeLists.txt | 2 +- .../{StreamPipelineV2.cpp => StreamPipeline.cpp} | 8 ++++---- third_party/amd/python/triton_amd.cc | 4 ++-- 10 files changed, 21 insertions(+), 21 deletions(-) rename third_party/amd/lib/TritonAMDGPUTransforms/{StreamPipelineV2.cpp => StreamPipeline.cpp} (99%) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 71d75b35dbf0..e873965e479a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -62,7 +62,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); - mlir::registerTritonAMDGPUStreamPipelineV2(); + mlir::registerTritonAMDGPUStreamPipeline(); mlir::registerTritonAMDGPUCanonicalizePointers(); mlir::registerTritonAMDGPUConvertToBufferOps(); mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 011116e1b201..b9f40dc29142 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,10 +1,10 @@ // RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 // RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm 
-allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// 
RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 module { // INSERT_IGLP0-LABEL: @test_dot_op diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 6ca0897578b6..641ff165d32f 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index bb7e102c9074..5d0cc41a66bc 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index c222be2cd64d..0029c76ec3e9 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -241,7 +241,7 @@ def make_ttgir(mod, metadata, options): "num_stages == 0. 
Now it will not happen anymore; " "please update to use num_stages == 2 for " "equivalent behavior in the past.") - amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, stream_prefetch) + amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 0a8d51bc8f44..630a1e903562 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -7,8 +7,8 @@ namespace mlir { -std::unique_ptr<Pass> createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2, - int prefetch = 0); +std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(int numStages = 2, + int prefetch = 0); std::unique_ptr<Pass> createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 85604dcaca18..6bee6da5fb45 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -3,7 +3,7 @@ include "mlir/Pass/PassBase.td" -def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { +def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ @@ -11,7 +11,7 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir tile }]; - let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; + let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index c3a69a5f9a2a..aef5886b11d8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -4,7 +4,7 @@ add_triton_library(TritonAMDGPUTransforms ConvertToBufferOps.cpp OptimizeEpilogue.cpp ReorderInstructions.cpp - StreamPipelineV2.cpp + StreamPipeline.cpp MfmaGroup.cpp DEPENDS diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp similarity index 99% rename from third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp rename to third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 2088fd80734c..79f956b06681 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -25,7 +25,7 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" -#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2" +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -857,7 +857,7 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) { } } -struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base<PipelinePass> { +struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> { PipelinePass() = default; PipelinePass(int32_t numStages, int32_t prefetch) { this->numStages = numStages; @@ -893,7 +893,7 @@ struct PipelinePass : 
public TritonAMDGPUStreamPipelineV2Base<PipelinePass> { }; } // anonymous namespace -std::unique_ptr<Pass> -mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages, int prefetch) { +std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(int numStages, + int prefetch) { return std::make_unique<PipelinePass>(numStages, prefetch); } diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 3c335099104d..8132773fc2a1 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -72,8 +72,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUConvertToBufferOpsPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_2("add_stream_pipelinev2", - mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int); + ADD_PASS_WRAPPER_2("add_stream_pipeline", + mlir::createTritonAMDGPUStreamPipelinePass, int, int); } void addControlConstant(llvm::Module *module, const char *name,
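Note for downstream users of the Python bindings: only the wrapper name changes, not its arguments. Below is a minimal sketch (not part of this patch) of driving the renamed pass on a TTGIR module; the helper name run_stream_pipeline and the standalone setup are hypothetical, while ir.pass_manager, passes.common.add_canonicalizer, and amd.passes.ttgpuir.add_stream_pipeline follow the usage shown in third_party/amd/backend/compiler.py.

    # Hypothetical helper (illustration only): run the renamed stream pipeline
    # pass on an existing TTGIR module, mirroring the make_ttgir() pattern.
    from triton._C.libtriton import ir, passes, amd

    def run_stream_pipeline(ttgir_mod, num_stages=2, prefetch=0):
        pm = ir.pass_manager(ttgir_mod.context)
        pm.enable_debug()
        # Formerly amd.passes.ttgpuir.add_stream_pipelinev2(...); same arguments.
        amd.passes.ttgpuir.add_stream_pipeline(pm, num_stages, prefetch)
        passes.common.add_canonicalizer(pm)
        pm.run(ttgir_mod)
        return ttgir_mod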