From 6556ec6050649e1fc42feb05a62ab9cc6908a722 Mon Sep 17 00:00:00 2001 From: Samuel Ginzburg Date: Thu, 16 Jan 2025 22:35:39 -0500 Subject: [PATCH] [AMD] Add support for buffer atomic RMW (#5549) # Overview This PR enables the raw.ptr.buffer.atomic.* RMW ops in the AMD backend. They feature similar calling conventions and semantics to the other buffer ops in the AMD backend. The new ops are gated behind the `AMDGCN_ENABLE_BUFFER_ATOMICS` environment variable which must be used in conjunction with `AMDGCN_USE_BUFFER_OPS`. They are also gated behind the GPU being CDNA3 (MI300-series GPUs) for now as the optimizations I added make assumptions regarding GFX942. I originally started exploratory work on the PR to better understand the comment in `LoadStoreOpToLLVM.cpp` referring to buffer atomics as "more efficient". In short I found that on their own they aren't necessarily more efficient, but using them in conjunction with more careful control over how cache coherence ops/memory fences are emitted can improve performance by a significant fraction. # How I've added a new buffer atomic RMW op in the AMDGPUOps dialect which has its own lowering in the backend. There are a number of checks in place to ensure that the lowering is done correctly between the ConvertToBufferOps pass and the LoadStoreOpToLLVM lowering. The actual lowering is where most of the performance gains come from. At a high-level, when non-buffer atomic RMW ops are emitted, the memory fences lower to something along the lines of: ```python buffer_wbl2 sc1 s_waitcnt lgkmcnt(0) atomicRMWop() s_waitcnt vmcnt(0) buffer_inv sc1 buffer_wbl2 sc1 s_waitcnt lgkmcnt(0) atomicRMWop() s_waitcnt vmcnt(0) buffer_inv sc1 ``` If my understanding of the [GFX942 memory model](https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942) is correct, then given several assumptions regarding CDNA3, this can actually be lowered to something that resembles: ```python buffer_wbl2 sc1 s_waitcnt lgkmcnt(0) atomicRMWop() s_waitcnt vmcnt(0) # AMDGCN specific cross-CU synchronization primitive atomicRMWop() s_waitcnt vmcnt(0) buffer_inv sc1 ``` There are comments in the code which explain the thought process for why (I think) that this is okay. It appears the AMD's CK library (AMD version of CUTLASS) uses similar synchronization mechanisms, although I am probably missing some of the context here for sure (https://github.com/ROCm/composable_kernel/blob/9e95d54cd2160dffc07c1197951a9ab1ca6c35f2/include/ck_tile/core/arch/amd_buffer_addressing.hpp#L619). # Results and Testing In addition to the added lit test, I ran the existing atomic rmw tests in tree with buffer ops + buffer atomics enabled and they appear to pass. Following this, I evaluated FP16 Split-K [gemm](https://github.com/pytorch-labs/tritonbench/blob/a2f668e38ec55978bfcf2a6a8d15294a5b9d3d36/tritonbench/operators/gemm/kernels/matmul.py#L190) with [llama shapes](https://github.com/pytorch-labs/tritonbench/blob/a2f668e38ec55978bfcf2a6a8d15294a5b9d3d36/tritonbench/utils/triton_op.py#L149) in tritonbench using an MI300x. Some minor modifications to the kernel were made to emit buffer ops (e.g., tl.assume calls). For testing purposes, I disabled the non split-k configurations. I also checked the numerical accuracy with rtol=atol=1e-4 for all shapes here. image Each bucket in the figure above corresponds to the average TFlops of all shapes with the same shared `M`-dim. At smaller batch sizes the performance is roughly equivalent. At BS=32, buffer atomics have ~50% greater TFlops. At BS=256 buffer atomics have ~3.75x the TFlops. Note: the purpose of this test is to evaluate the performance of buffer atomics---split-k is not always optimal for these shapes/workload etc... --- test/Conversion/amd/buffer_load_store.mlir | 37 ++ .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 30 +- third_party/amd/backend/compiler.py | 3 +- .../include/Dialect/TritonAMDGPU/IR/Dialect.h | 1 + .../TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td | 1 + .../TritonAMDGPU/IR/TritonAMDGPUDialect.td | 2 +- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 49 ++- .../include/TritonAMDGPUTransforms/Passes.h | 3 +- .../include/TritonAMDGPUTransforms/Passes.td | 8 +- .../TritonAMDGPUToLLVM/BufferOpsEmitter.cpp | 81 ++++- .../lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h | 12 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 343 ++++++++++++++---- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 98 ++++- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 24 ++ .../ConvertToBufferOps.cpp | 154 +++++++- third_party/amd/python/triton_amd.cc | 5 +- 16 files changed, 735 insertions(+), 116 deletions(-) diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir index 39abbcb5839b..1ad1257cf00e 100644 --- a/test/Conversion/amd/buffer_load_store.mlir +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -178,3 +178,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.return } } + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_atomic + tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>) { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %mask = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // There should be a single release fence before any atomics + // CHECK: llvm.fence syncscope("agent") release + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + + // We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call. + %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask : tensor<128xf32, #blocked0> + + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + + // There should be a single acquire fence after all of the atomics + // CHECK: llvm.fence syncscope("agent") acquire + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index b1f13396639c..99e75b26cb5d 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx940'| FileCheck %s #blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { @@ -482,3 +482,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset_buffer_atomic + tt.func @assume_positive_offset_buffer_atomic(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 6019e01e8cc5..7115b2786d8f 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -220,6 +220,7 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_optimize_dot_operands(pm, True) stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" + use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1" # The `local-prefetch` scheduling variant requires turning on buffer ops. if options.instruction_sched_variant == "local-prefetch": @@ -247,7 +248,7 @@ def make_ttgir(mod, metadata, options): if use_buffer_ops: amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) - amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) + amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 845cca1fd8f2..8356b3cbc3e6 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Traits.h" // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 491989669c6a..44ac3c8d3aaf 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -25,6 +25,7 @@ #define TRITON_AMDGPU_ATTRDEFS include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "mlir/IR/EnumAttr.td" diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index c0c18b07e907..91a3d3230bed 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -34,7 +34,7 @@ def TritonAMDGPU_Dialect : Dialect { TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level. }]; - let dependentDialects = []; + let dependentDialects = ["triton::TritonDialect"]; let useDefaultAttributePrinterParser = 1; let usePropertiesForAttributes = 1; diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index c8e3922d0254..a0110d122866 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -26,13 +26,16 @@ #define TRITON_AMDGPU_OPS include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + include "mlir/IR/EnumAttr.td" -include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" @@ -209,6 +212,48 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ }]; } +def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [ + SameLoadStoreOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer atomic RMW operation. Buffer atomics are similar to normal atomics, but access global memory via a + scalar base pointer and a tensor of offsets instead of a tensor of pointers. + Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with + the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). + Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with + the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped. + }]; + let arguments = ( + ins + TT_AtomicRMWAttr:$atomic_rmw_op, + TT_Ptr:$ptr, + I32Tensor:$offsets, + TT_Tensor:$value, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope, + Optional:$mask + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)? + attr-dict `:` type($result) + }]; +} + + def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ SameLoadStoreOperandsEncoding, MemoryEffects<[MemWrite]>, diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 6be2e87f7755..c375d2a386b2 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -26,7 +26,8 @@ std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); -std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); +std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( + std::string archGenName = std::string()); std::unique_ptr createTritonAMDGPUBlockPingpongPass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index bf25f7bf0c97..f026d1d59521 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -117,11 +117,17 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", " def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> { let summary = "Convert memory operations to buffer operations"; - let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible"; + let description = "This pass converts memory and atomic operations (e.g., tt.load/tt.store/tt.atomic_rmw) to amdgpu buffer operations, if possible"; let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()"; let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"archGenerationName", "arch-generation-name", + "std::string", /*default=*/"std::string{}", + "GFX generation name of target device.">, + ]; } def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::ModuleOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index 05a8e150903a..802e7ea703e3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -77,7 +77,7 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, triton::CacheModifier cm) { SmallVector args; fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, args); - Type bufferType = getBufferOpType(type); + Type bufferType = getBufferOpType(type, false); Value data = rewriter.create( loc, bufferType, args, ArrayRef()); data = bitcast(data, type); @@ -86,10 +86,34 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, return data; } +Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, + Value offset, Value data, Value pred, + bool hasUsers) { + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(type, true); + if (vecTy != bufferType) + data = bitcast(data, bufferType); + + SmallVector args{data}; + fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args); + + // TODO: + // The ops in ROCDL (e.g., RawPtrBufferAtomicFaddOp) have no return value, + // but they lower to instrinsics that can return values. This causes the + // LLVM verifier to fail. When this is fixed, the ROCDL ops should be used + // here. + auto rmwOpStr = stringifyRMWOp(rmwType).str(); + auto instrinsic = "llvm.amdgcn.raw.ptr.buffer.atomic." + rmwOpStr; + auto bufferAtomicRMW = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, instrinsic, bufferType, args); + + return bitcast(bufferAtomicRMW.getResult(0), type); +} + void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, Value pred, triton::CacheModifier cm) { VectorType vecTy = cast(data.getType()); - Type bufferType = getBufferOpType(vecTy); + Type bufferType = getBufferOpType(vecTy, false); if (vecTy != bufferType) data = bitcast(data, bufferType); SmallVector args{data}; @@ -99,7 +123,7 @@ void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, ArrayRef()); } -Type BufferEmitter::getBufferOpType(Type type) { +Type BufferEmitter::getBufferOpType(Type type, bool atomicsOp) { int64_t vecSize = 1; Type elementType = type; if (auto vecType = dyn_cast(type)) { @@ -110,16 +134,20 @@ Type BufferEmitter::getBufferOpType(Type type) { const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); const size_t totalWidthBits = valueElemNBits * vecSize; - // For bf16, always convert to i16 Type bufferElementType = elementType; - if (elementType.isBF16()) + // We don't want to cast from bf16 if we are emitting buffer atomics + if (elementType.isBF16() && !atomicsOp) { bufferElementType = rewriter.getI16Type(); + } // If we are dealing with a subword type (e.g., i8 or f16) but we // still need multiple words, then pack the subwords into 32bit integers // and update the vector length and the type + // We never need to pack for buffer atomics because we ensure + // 1) We can always emit a 32-bit / 64-bit atomics op + // 2) For tensors of 16-bit values that the values are contiguous int64_t bufferVecSize = vecSize; - if (valueElemNBits < 32) { + if (valueElemNBits < 32 && !atomicsOp) { if (totalWidthBits > 32) { bufferElementType = rewriter.getI32Type(); bufferVecSize = totalWidthBits / 32; @@ -166,10 +194,49 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, getCtrlBitsForCacheModifierOnTarget(cm, isBufferLoad, targetInfo); Value cacheModifiers = int_val(32, aux); - // 5. Add the arguments + // 4. Add the arguments args.push_back(rsrcDesc); args.push_back(maskedOffsetBytes); args.push_back(sgprOffset); args.push_back(cacheModifiers); } + +void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + bool hasUsers, + SmallVector &args) { + + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = int_val(32, 0); + + // 3. Create the cache modifiers word + int32_t aux = 0; + if (hasUsers) + aux = getCtrlBitsForBufferAtomicsOnGFX942(/*setSC0*/ true, /*setSC1*/ false, + /*setNT*/ false); + else + aux = getCtrlBitsForBufferAtomicsOnGFX942( + /*setSC0*/ false, /*setSC1*/ false, /*setNT*/ false); + + Value cacheModifiers = int_val(32, aux); + + // 4. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h index 29522f1f95ed..9c3633a499b3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -9,6 +9,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include namespace mlir::LLVM::AMD { // Utility class to take care of buffer operation emission. We may add more @@ -69,6 +70,10 @@ struct BufferEmitter { Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, Value falseVal, CacheModifier cm); + // Emit a predicated rocdl.raw.ptr.buffer.atomic.* RMWOp + Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset, + Value data, Value pred, bool hasUsers); + // Emit a predicated rocdl.raw.ptr.buffer.store void emitStore(Value rsrcDesc, Value offset, Value data, Value pred, CacheModifier cm); @@ -79,10 +84,15 @@ struct BufferEmitter { CacheModifier cm, bool isBufferLoad, SmallVector &args); + // Fill buffer atomics arguments + void fillCommonArgsAtomics(Type type, Value rsrcDesc, Value vOffsetElems, + Value pred, bool hasUsers, + SmallVector &args); + // Given a type, the buffer type can be either the same type // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to // a vector of 4xi32. This usually makes the life of the backend easier - Type getBufferOpType(Type type); + Type getBufferOpType(Type type, bool atomicsOp); // Rewriter utilities RewriterBase &rewriter; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 75fcf7bc93cc..2c6afdc9dbe8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -20,6 +20,7 @@ using namespace mlir::triton::gpu; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryBase; +using ::mlir::LLVM::AMD::getVectorSize; using ::mlir::LLVM::AMD::llLoad; using ::mlir::LLVM::AMD::llStore; using ::mlir::triton::gpu::getTotalElemsPerThread; @@ -154,57 +155,6 @@ struct LoadStoreConversionBase { return offsetType.cloneWith(std::nullopt, basePtrType); } - // Get contiguity for a tensor pointer `ptr` - unsigned getContiguity(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - return 1; - return axisAnalysisPass.getPtrContiguity(ptr); - } - - // Get contiguity for a scalar pointer `ptr` and a tensor `offset` - unsigned getContiguity(Value ptr, Value offset) const { - // Get contiguity from the offset - Type type = getPointerTypeWithShape(ptr, offset); - RankedTensorType tensorTy = cast(type); - auto layout = tensorTy.getEncoding(); - auto order = triton::gpu::getOrder(layout); - auto uniqueContigPerThread = - triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); - assert(order[0] < uniqueContigPerThread.size() && - "Unexpected uniqueContigPerThread size"); - unsigned contiguity = uniqueContigPerThread[order[0]]; - - // Get alignment from the pointer. Since this is a scalar pointer - // we should not take the pointer contiguity to consider alignment - auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); - auto maxMultipleBytes = axisInfo->getDivisibility(0); - auto elemNumBits = triton::getPointeeBitWidth(tensorTy); - auto elemNumBytes = std::max(elemNumBits / 8, 1); - auto align = std::max(maxMultipleBytes / elemNumBytes, 1); - - // Final contiguity is a min of the offset contiguity and pointer alignment - contiguity = std::min(align, contiguity); - return contiguity; - } - - // Determine the vector size of a tensor of pointers - unsigned getVectorSize(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - return 1; - auto contiguity = getContiguity(ptr); - auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); - return std::min(128 / pointeeBitWidth, contiguity); - } - - // Given a scalar pointer and a tensor of offsets, determine the vector size - unsigned getVectorSize(Value ptr, Value offset) const { - auto contiguity = getContiguity(ptr, offset); - auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); - return std::min(128 / pointeeBitWidth, contiguity); - } - // Unpack the elements contained in a `llvmStruct` into a `SmallVector` of // `Value`s. While you do that, check also the alignment of the mask and // update the vector length `vec` accordingly @@ -288,7 +238,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type valueTy = op.getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); - unsigned vec = getVectorSize(ptr); + unsigned vec = getVectorSize(ptr, axisAnalysisPass); unsigned numElems = getTotalElemsPerThread(ptr.getType()); // Get the LLVM values for pointers @@ -336,7 +286,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, ptrAlignmentBytes, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + rewriter, loc, getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -392,7 +342,7 @@ struct BufferLoadOpConversion typeConverter->convertType(getElementTypeOrSelf(valueTy)); Type ptrType = getPointerTypeWithShape(ptr, offset); unsigned numElems = getTotalElemsPerThread(ptrType); - unsigned vec = getVectorSize(ptr, offset); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); // Get the offset SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); @@ -422,7 +372,7 @@ struct BufferLoadOpConversion vecTy, rsrcDesc, offsetElems[vecStart], pred, falseVal, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + rewriter, loc, getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -470,7 +420,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(valueTy)); // Determine the vectorization size - unsigned vec = getVectorSize(ptr); + unsigned vec = getVectorSize(ptr, axisAnalysisPass); unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); auto ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -514,6 +464,254 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } }; +static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return LLVM::AtomicOrdering::acq_rel; + } +} + +struct BufferAtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferAtomicRMWOp>::ConvertOpToLLVMPattern; + + BufferAtomicRMWOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferAtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + auto atomicRmwAttr = op.getAtomicRmwOp(); + + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + + // v4f16 and v4bf16 variants of buffer atomics do not exist. + // only v2f16 and v2bf16. + if (valueElemTy.isBF16() || valueElemTy.isF16()) { + // We clamp to the only supported vectorization width here (2). + // In ConvertToBufferOps we check that we have a large enough vector size + assert(vec >= 2); + vec = 2u; + // The max width of a buffer atomic op is 64-bits + // Some types like F32 don't have a 2x vectorized version + } else if (valueElemTy.isF32() || valueElemTy.isF64() || + valueElemTy.isInteger(32) || valueElemTy.isInteger(64)) { + vec = 1u; + } + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // We need to manually emit memory fences (LLVM doesn't do this for buffer + // ops) see: https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 + auto memOrdering = op.getSem(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto rel = LLVM::AtomicOrdering::release; + auto acq = LLVM::AtomicOrdering::acquire; + + bool emitReleaseFence = false; + bool emitAcquireFence = false; + switch (memOrdering) { + case MemSemantic::RELAXED: + // In this case, no memory fences are needed + break; + case MemSemantic::RELEASE: + emitReleaseFence = true; + break; + case MemSemantic::ACQUIRE: + emitAcquireFence = true; + break; + case MemSemantic::ACQUIRE_RELEASE: + emitAcquireFence = true; + emitReleaseFence = true; + default: + // default == acq_rel, so we emit the same barriers + emitAcquireFence = true; + emitReleaseFence = true; + } + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + SmallVector loadedVals; + + // set the scope + auto memScope = op.getScope(); + auto scopeStr = ""; + switch (memScope) { + // System scope is not supported yet + case MemSyncScope::SYSTEM: + return failure(); + case MemSyncScope::GPU: + scopeStr = "agent"; + break; + case MemSyncScope::CTA: + scopeStr = "workgroup"; + break; + default: + return failure(); + } + + StringAttr scope = mlir::StringAttr::get(loc.getContext(), scopeStr); + + if (emitReleaseFence) + rewriter.create(loc, TypeRange{}, rel, scope); + + mlir::Operation *lastRMWOp; + MLIRContext *ctx = rewriter.getContext(); + GCNBuilder waitcntBuilder; + + // Triton supports three scopes for atomic access + // 1. System + // 2. GPU (default) + // 3. CTA (i.e., threadblock or warp-group) + // + // Currently, the AMD backend emits atomics with agent-scope. + // + // The following properties are used to emit proper synchronization + // primitives between sequential buffer atomics See: Memory Model GFX942 + // (MI300 series) + // https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942: + // + // buffer/global/flat_load/store/atomic instructions to global memory are + // termed vector memory operations. + // + // 1. Vector memory operations access a single vector L1 cache shared by + // all SIMDs a CU. + // No special action is required for coherence between wavefronts in the + // same work-group since they execute on the same CU. + // + // 2. Each CU has a separate request queue per channel for its associated + // L2. + // Therefore, the vector and scalar memory operations performed by + // wavefronts executing with different L1 caches and the same L2 cache + // can be reordered relative to each other. A `s_waitcnt vmcnt(0)` is + // required to ensure synchronization between vector memory operations of + // different CUs. It ensures a previous vector memory operation has + // completed before executing a subsequent vector memory or LDS operation + // and so can be used to meet the requirements of acquire and release. + // + // 3. Atomic read-modify-write instructions implicitly bypass the L1 cache + // (specific to gfx942) + // Therefore, they do not use the sc0 bit for coherence and instead use + // it to indicate if the instruction returns the original value being + // updated. They do use sc1 to indicate system or agent scope coherence. + // See the cache modifiers word in BufferEmitter::fillCommonArgs for + // more details. + // + // In summary: + // 1. We have to emit memory fences (i.e., acq/rel/acq_rel) before and after + // our buffer atomics. + // 2. Because buffer atomic rmw ops skip the l1 cache, s_waitcnt vmcnt(0) is + // sufficient for synchronization between instructions. + // We don't need to invalidate L1 between these ops on GFX942, just after + // (i.e., we can skip `buffer_wbinvl1_vol`) + // 3. We don't have to explicitly write to the l2 cache because + // `s_waitcnt vmcnt(0)` already does this as-per the MI300/CDNA3 ISA + // docs: "Decremented for reads when the data has been written back to + // the VGPRs, and for writes when the data has been written to the L2 + // cache. Ordering: Memory reads and writes return in the order they were + // issued, including mixing reads and writes" + // 4. We set GLC=1, to return the old value. Atomics in GFX942 execute with + // either device (default) or system scope (controlled by the sc1 flag). + // This is distinct from the memory scope of the atomic (i.e, the memory + // fences which appear before/after the ops). + + if (memScope == MemSyncScope::GPU) { + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } else if (memScope == MemSyncScope::CTA) { + // TODO: Within a CTA we can possibly relax this? + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } + + // Check if the op has users, if it does we set GLC=1, otherwise GLC=0 + auto opUsers = op.getResult().getUsers(); + auto hasUsers = std::distance(opUsers.begin(), opUsers.end()) > 0; + + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + + Value loadVal = bufferEmitter.emitAtomicRMW( + atomicRmwAttr, vecTy, rsrcDesc, offsetElems[vecStart], storeVal, pred, + hasUsers); + // Track the last op, so we can emit a fenceop after the loop + lastRMWOp = loadVal.getDefiningOp(); + + // To sync vector memory ops between CUs within an agent, we need an + // s_waitcnt skip doing this on the last iteration of the loop + // In the relaxed memory ordering, we don't need this barrier + if (vecStart < numElems - vec && (emitReleaseFence || emitAcquireFence)) { + Value inst = + waitcntBuilder.launch(rewriter, lastRMWOp->getLoc(), void_ty(ctx)); + lastRMWOp = inst.getDefiningOp(); + } + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + // Acquire Fence post-atomic + if (emitAcquireFence) + rewriter.create(lastRMWOp->getLoc(), TypeRange{}, acq, + scope); + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + struct BufferStoreOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -553,7 +751,7 @@ struct BufferStoreOpConversion Type ptrType = getPointerTypeWithShape(ptr, offset); unsigned numElems = getTotalElemsPerThread(ptrType); - unsigned vec = getVectorSize(ptr, offset); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); // Get the offsets and value SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); @@ -581,21 +779,6 @@ struct BufferStoreOpConversion } }; -static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { - switch (memOrdering) { - case MemSemantic::RELAXED: - return LLVM::AtomicOrdering::monotonic; - case MemSemantic::ACQUIRE: - return LLVM::AtomicOrdering::acquire; - case MemSemantic::RELEASE: - return LLVM::AtomicOrdering::release; - case MemSemantic::ACQUIRE_RELEASE: - return LLVM::AtomicOrdering::acq_rel; - default: - return LLVM::AtomicOrdering::acq_rel; - } -} - struct AtomicCASOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -641,7 +824,7 @@ struct AtomicCASOpConversion auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar - auto vec = getVectorSize(op.getPtr()); + auto vec = getVectorSize(op.getPtr(), axisAnalysisPass); // tensor if (TensorTy) { auto valTy = cast(op.getVal().getType()); @@ -839,7 +1022,7 @@ struct AtomicRMWOpConversion const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); + auto vec = getVectorSize(ptr, axisAnalysisPass); int numElems = 1; Type packF16Ty = vec_ty(valueElemTy, 2); @@ -940,13 +1123,13 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToEnd(atomicBlock); auto maybeKind = matchAtomicOp(atomicRmwAttr); - // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient - // atomics for MI-* series of AMD GPU. + Value atom = rewriter .create(loc, *maybeKind, rmwPtr, operand, atomicMemOrdering, StringRef(scopeStr.value())) .getResult(); + if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = @@ -1009,9 +1192,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns - .add( - typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 0766b611ee88..ad7d892b3270 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -6,6 +6,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::ModuleAxisInfoAnalysis; using mlir::triton::AMD::DppCtrl; using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; @@ -438,19 +439,22 @@ getCacheModifierFlagsForPredicatedCall(LLVM::CallOp callOp) { // - SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system // - NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse // -// -----+-----+-----+-----+----+-- -// Op | cm | SC1 | SC0 | NT | -// -----+-----+-----+-----+----+-- -// Load | .ca | 0 | 0 | 0 | -// | .cg | 0 | 1 | 1 | -// | .cs | 0 | 1 | 1 | -// | .cv | 1 | 1 | x | -// -----+-----+-----+-----+----+-- -// Store| .wb | 0 | 0 | 0 | -// | .cg | 0 | 0 | 0 | -// | .cs | 0 | 1 | 1 | -// | .wt | 1 | x | x | -// -----+-----+-----+-----+----+-- +// -------+-----+-----+-----+----+-- +// Op | cm | SC1 | SC0 | NT | +// -------+-----+-----+-----+----+-- +// Load | .ca | 0 | 0 | 0 | +// | .cg | 0 | 1 | 1 | +// | .cs | 0 | 1 | 1 | +// | .cv | 1 | 1 | x | +// -------+-----+-----+-----+----+-- +// Store | .wb | 0 | 0 | 0 | +// | .cg | 0 | 0 | 0 | +// | .cs | 0 | 1 | 1 | +// | .wt | 1 | x | x | +// -------+-----+-----+-----+----+-- +// Atomic | N/A | 0 | 1 | x | Setting sc0 returns the pre-op value +// | N/A | 1 | 0 | x | Setting sc1 performs a system-scope atomic +// -------+-----+-----+-----+----+-- static int32_t getCtrlBitsForCacheModifierOnGFX942(triton::CacheModifier cm, bool isBufferLoad) { const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b1000; @@ -481,6 +485,19 @@ static int32_t getCtrlBitsForCacheModifierOnGFX942(triton::CacheModifier cm, return aux; } +int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1, + bool setNT) { + const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b1000; + int32_t aux = 0; + if (setSC0) + aux |= sc0Bit; + if (setSC1) + aux |= sc1Bit; + if (setNT) + aux |= ntBit; + return aux; +} + static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) { return 0; } @@ -520,4 +537,59 @@ Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, return builder.launch(rewriter, loc, f16_ty, false); } +Type getPointerTypeWithShape(Value basePtr, Value offset) { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); +} + +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getPtrContiguity(ptr); +} + +unsigned getContiguity(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + // Get contiguity from the offset + Type type = getPointerTypeWithShape(ptr, offset); + RankedTensorType tensorTy = cast(type); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + + // Get alignment from the pointer. Since this is a scalar pointer + // we should not take the pointer contiguity to consider alignment + auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); + auto maxMultipleBytes = axisInfo->getDivisibility(0); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto align = std::max(maxMultipleBytes / elemNumBytes, 1); + + // Final contiguity is a min of the offset contiguity and pointer alignment + contiguity = std::min(align, contiguity); + return contiguity; +} + +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + return std::min(128 / pointeeBitWidth, contiguity); +} + +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto contiguity = getContiguity(ptr, offset, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); + return std::min(128 / pointeeBitWidth, contiguity); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index a79913262a49..1dabe31db2d9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -7,6 +7,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -56,8 +57,31 @@ std::pair getCacheModifierFlagsForPredicatedCall(LLVM::CallOp); int32_t getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool, mlir::triton::AMD::TargetInfo &); +// Get cache modifier information for buffer atomics +int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1, + bool setNT); + Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, triton::RoundingMode rounding); + +// Return a tensor of pointers with the same type of `basePtr` and the same +// shape of `offset` +Type getPointerTypeWithShape(Value basePtr, Value offset); + +// Get contiguity for a tensor pointer `ptr` +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Get contiguity for a scalar pointer `ptr` and a tensor `offset` +unsigned getContiguity(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Determine the vector size of a tensor of pointers +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Given a scalar pointer and a tensor of offsets, determine the vector size +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index b9e2aaf3668e..9378596e360d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -12,6 +12,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -23,11 +25,15 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h" +#undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; +using ::mlir::LLVM::AMD::getVectorSize; +using mlir::triton::AMD::ISAFamily; + namespace ttg = mlir::triton::gpu; namespace tt = mlir::triton; @@ -227,6 +233,129 @@ bool canUseBufferOps(Value ptr, const DenseSet &assumptions) { } } // namespace +struct ConvertTritonAtomicRMWOpToBufferAtomicRMW + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonAtomicRMWOpToBufferAtomicRMW( + mlir::MLIRContext *context, DenseSet &assumptions, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : mlir::OpRewritePattern(context), + assumptions(assumptions), axisAnalysisPass(axisAnalysisPass) {} + + mlir::LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + auto atomicRmwOp = op.getAtomicRmwOp(); + auto sem = op.getSem(); + auto scope = op.getScope(); + + // In addition to the `canUserBufferOps` check, we should ensure that + // 1. Perform the canUserBufferOps check + if (!canUseBufferOps(ptr, assumptions)) { + return rewriter.notifyMatchFailure(op, "canUseBufferOps check failed"); + } + + // 2. Check the scope. We support GPU and CTA for now (SYSTEM scope is not + // supported yet) + switch (scope) { + case MemSyncScope::GPU: + case MemSyncScope::CTA: + break; + default: + return rewriter.notifyMatchFailure(op, "RMW with unsupported scope"); + } + LDBG("RMW supported scope"); + + // 3. Check the memory ordering. + // TODO: support monotonic + switch (sem) { + case MemSemantic::RELAXED: + case MemSemantic::RELEASE: + case MemSemantic::ACQUIRE: + case MemSemantic::ACQUIRE_RELEASE: + break; + default: + return rewriter.notifyMatchFailure( + op, "RMW with unsupported memory ordering"); + } + + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + + // 4. Buffer atomic RMW does not support FP8 ops + // easier to just check what we support + auto checkType = getElementTypeOrSelf(op.getVal()); + bool isSupportedType = checkType.isF16() || checkType.isBF16() || + checkType.isF32() || checkType.isF64() || + checkType.isInteger(32) || checkType.isInteger(64); + if (!isSupportedType) { + return rewriter.notifyMatchFailure(op, "RMW with unsupported type"); + } + LDBG("RMW supported type"); + + // 5. Check if the RMWOp is supported + switch (atomicRmwOp) { + case RMWOp::AND: + case RMWOp::OR: + case RMWOp::XOR: + case RMWOp::ADD: + case RMWOp::FADD: + case RMWOp::MAX: + case RMWOp::MIN: + case RMWOp::UMAX: + case RMWOp::UMIN: + case RMWOp::XCHG: + break; + default: + auto rmwOpStr = stringifyRMWOp(atomicRmwOp).str(); + return rewriter.notifyMatchFailure(op, "RMW with unsupported op: " + + rmwOpStr); + } + LDBG("RMW supported Op"); + + // 6. Buffer atomics support 32 and 64-bit operations, so inputs must be at + // least 32-bits. Otherwise, fall back to the existing path for atomics + auto opValueType = op.getVal().getType(); + auto opBitWidth = 0; + if (auto tensorType = dyn_cast(opValueType)) { + // We can't just compute the opBitWidth using the numElements * + // elemBitWidth here. In cases such as tensor<2xf16...>, if the elements + // are contiguous we can emit the buffer op. Otherwise, the buffer ops + // lowering will try to emit individual (unsupported) f16/bf16 ops. + auto elemBitWidth = tensorType.getElementTypeBitWidth(); + opBitWidth = + getVectorSize(basePtr, tensorOffset, axisAnalysisPass) * elemBitWidth; + } else { + opBitWidth = opValueType.getIntOrFloatBitWidth(); + } + + if (opBitWidth < 32) { + return rewriter.notifyMatchFailure(op, "RMW requires opBitWidth >= 32"); + } + + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + + rewriter.replaceOpWithNewOp( + op, op.getVal().getType(), atomicRmwOp, basePtr, tensorOffset, + op.getVal(), sem, scope, maybeMask); + + return success(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + struct ConvertTritonLoadToBufferLoad : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -267,7 +396,6 @@ struct ConvertTritonLoadToBufferLoad opIdxAttr); } rewriter.replaceOp(op, bufferLoadOp); - return success(); } LDBG("Failed to convert: " << op); @@ -322,13 +450,17 @@ class TritonAMDGPUConvertToBufferOpsPass public: TritonAMDGPUConvertToBufferOpsPass() = default; + TritonAMDGPUConvertToBufferOpsPass(StringRef archGen) { + this->archGenerationName = archGen.data(); + }; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - ModuleOp m = getOperation(); + ModuleOp mod = getOperation(); + // Collect assumptions in the function DenseSet assumptions; - m.walk([&](LLVM::AssumeOp op) { + mod.walk([&](LLVM::AssumeOp op) { if (op->getOperand(0).getDefiningOp()) assumptions.insert(op->getOperand(0)); }); @@ -337,13 +469,23 @@ class TritonAMDGPUConvertToBufferOpsPass LDBG("Assumption:" << assume); } + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); patterns.add(context, assumptions); patterns.add(context, assumptions); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + + // Gate buffer atomics behind CDNA3 (i.e., MI300 series) for now + // GFX942-specific assumptions regarding cache coherence are made when + // lowering to LLVM + if (ISAFamily::CDNA3 == triton::AMD::deduceISAFamily(archGenerationName)) + patterns.add( + context, assumptions, axisInfoAnalysis); + + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); } }; -std::unique_ptr mlir::createTritonAMDGPUConvertToBufferOpsPass() { - return std::make_unique(); +std::unique_ptr +mlir::createTritonAMDGPUConvertToBufferOpsPass(std::string archGen) { + return std::make_unique(archGen); } diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 1e999d631b3f..9e27e06b97e9 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -68,8 +68,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUOptimizeEpiloguePass); ADD_PASS_WRAPPER_0("add_canonicalize_pointers", mlir::createTritonAMDGPUCanonicalizePointersPass); - ADD_PASS_WRAPPER_0("add_convert_to_buffer_ops", - mlir::createTritonAMDGPUConvertToBufferOpsPass); + ADD_PASS_WRAPPER_1("add_convert_to_buffer_ops", + mlir::createTritonAMDGPUConvertToBufferOpsPass, + const std::string &); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_0("add_block_pingpong",