diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir index 39abbcb5839b..1ad1257cf00e 100644 --- a/test/Conversion/amd/buffer_load_store.mlir +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -178,3 +178,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.return } } + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_atomic + tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>) { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %mask = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // There should be a single release fence before any atomics + // CHECK: llvm.fence syncscope("agent") release + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + + // We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call. + %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask : tensor<128xf32, #blocked0> + + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + + // There should be a single acquire fence after all of the atomics + // CHECK: llvm.fence syncscope("agent") acquire + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index b1f13396639c..99e75b26cb5d 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx940'| FileCheck %s #blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp 
= [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { @@ -482,3 +482,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset_buffer_atomic + tt.func @assume_positive_offset_buffer_atomic(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 6019e01e8cc5..7115b2786d8f 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -220,6 +220,7 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_optimize_dot_operands(pm, True) stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" + use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1" # The `local-prefetch` scheduling variant requires turning on buffer ops. 
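    # Buffer atomics are additionally gated on the target architecture: the arch
    # string passed to add_convert_to_buffer_ops below restricts the atomic
    # conversion to CDNA3 (e.g. gfx942), whose cache-coherence behavior the
    # LLVM lowering relies on.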
if options.instruction_sched_variant == "local-prefetch": @@ -247,7 +248,7 @@ def make_ttgir(mod, metadata, options): if use_buffer_ops: amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) - amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) + amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 845cca1fd8f2..8356b3cbc3e6 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Traits.h" // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 491989669c6a..44ac3c8d3aaf 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -25,6 +25,7 @@ #define TRITON_AMDGPU_ATTRDEFS include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "mlir/IR/EnumAttr.td" diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index c0c18b07e907..91a3d3230bed 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -34,7 +34,7 @@ def TritonAMDGPU_Dialect : Dialect { TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level. 
}]; - let dependentDialects = []; + let dependentDialects = ["triton::TritonDialect"]; let useDefaultAttributePrinterParser = 1; let usePropertiesForAttributes = 1; diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index c8e3922d0254..a0110d122866 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -26,13 +26,16 @@ #define TRITON_AMDGPU_OPS include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + include "mlir/IR/EnumAttr.td" -include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" @@ -209,6 +212,48 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ }]; } +def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [ + SameLoadStoreOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer atomic RMW operation. Buffer atomics are similar to normal atomics, but access global memory via a + scalar base pointer and a tensor of offsets instead of a tensor of pointers. + Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with + the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). + Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with + the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped. + }]; + let arguments = ( + ins + TT_AtomicRMWAttr:$atomic_rmw_op, + TT_Ptr:$ptr, + I32Tensor:$offsets, + TT_Tensor:$value, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope, + Optional:$mask + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)? 
+ attr-dict `:` type($result) + }]; +} + + def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ SameLoadStoreOperandsEncoding, MemoryEffects<[MemWrite]>, diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 6be2e87f7755..c375d2a386b2 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -26,7 +26,8 @@ std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); -std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); +std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( + std::string archGenName = std::string()); std::unique_ptr createTritonAMDGPUBlockPingpongPass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index bf25f7bf0c97..f026d1d59521 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -117,11 +117,17 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", " def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> { let summary = "Convert memory operations to buffer operations"; - let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible"; + let description = "This pass converts memory and atomic operations (e.g., tt.load/tt.store/tt.atomic_rmw) to amdgpu buffer operations, if possible"; let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()"; let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"archGenerationName", "arch-generation-name", + "std::string", /*default=*/"std::string{}", + "GFX generation name of target device.">, + ]; } def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::ModuleOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index 05a8e150903a..802e7ea703e3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -77,7 +77,7 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, triton::CacheModifier cm) { SmallVector args; fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, args); - Type bufferType = getBufferOpType(type); + Type bufferType = getBufferOpType(type, false); Value data = rewriter.create( loc, bufferType, args, ArrayRef()); data = bitcast(data, type); @@ -86,10 +86,34 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, return data; } +Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, + Value offset, Value data, Value pred, + bool hasUsers) { + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(type, true); + if (vecTy != bufferType) + data = bitcast(data, bufferType); + + SmallVector args{data}; + fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args); + + // TODO: + // The ops in ROCDL (e.g., RawPtrBufferAtomicFaddOp) have no return value, + // but they lower to instrinsics that can return values. This causes the + // LLVM verifier to fail. When this is fixed, the ROCDL ops should be used + // here. 
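  // Illustrative shape of the emitted call for an f32 fadd (cf. the CHECK
  // lines in buffer_load_store.mlir above):
  //   llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"
  //       (%val, %rsrc, %voffset, %soffset, %aux)
  //       : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32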
+ auto rmwOpStr = stringifyRMWOp(rmwType).str(); + auto instrinsic = "llvm.amdgcn.raw.ptr.buffer.atomic." + rmwOpStr; + auto bufferAtomicRMW = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, instrinsic, bufferType, args); + + return bitcast(bufferAtomicRMW.getResult(0), type); +} + void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, Value pred, triton::CacheModifier cm) { VectorType vecTy = cast(data.getType()); - Type bufferType = getBufferOpType(vecTy); + Type bufferType = getBufferOpType(vecTy, false); if (vecTy != bufferType) data = bitcast(data, bufferType); SmallVector args{data}; @@ -99,7 +123,7 @@ void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, ArrayRef()); } -Type BufferEmitter::getBufferOpType(Type type) { +Type BufferEmitter::getBufferOpType(Type type, bool atomicsOp) { int64_t vecSize = 1; Type elementType = type; if (auto vecType = dyn_cast(type)) { @@ -110,16 +134,20 @@ Type BufferEmitter::getBufferOpType(Type type) { const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); const size_t totalWidthBits = valueElemNBits * vecSize; - // For bf16, always convert to i16 Type bufferElementType = elementType; - if (elementType.isBF16()) + // We don't want to cast from bf16 if we are emitting buffer atomics + if (elementType.isBF16() && !atomicsOp) { bufferElementType = rewriter.getI16Type(); + } // If we are dealing with a subword type (e.g., i8 or f16) but we // still need multiple words, then pack the subwords into 32bit integers // and update the vector length and the type + // We never need to pack for buffer atomics because we ensure + // 1) We can always emit a 32-bit / 64-bit atomics op + // 2) For tensors of 16-bit values that the values are contiguous int64_t bufferVecSize = vecSize; - if (valueElemNBits < 32) { + if (valueElemNBits < 32 && !atomicsOp) { if (totalWidthBits > 32) { bufferElementType = rewriter.getI32Type(); bufferVecSize = totalWidthBits / 32; @@ -166,10 +194,49 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, getCtrlBitsForCacheModifierOnTarget(cm, isBufferLoad, targetInfo); Value cacheModifiers = int_val(32, aux); - // 5. Add the arguments + // 4. Add the arguments args.push_back(rsrcDesc); args.push_back(maskedOffsetBytes); args.push_back(sgprOffset); args.push_back(cacheModifiers); } + +void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + bool hasUsers, + SmallVector &args) { + + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = int_val(32, 0); + + // 3. 
Create the cache modifiers word + int32_t aux = 0; + if (hasUsers) + aux = getCtrlBitsForBufferAtomicsOnGFX942(/*setSC0*/ true, /*setSC1*/ false, + /*setNT*/ false); + else + aux = getCtrlBitsForBufferAtomicsOnGFX942( + /*setSC0*/ false, /*setSC1*/ false, /*setNT*/ false); + + Value cacheModifiers = int_val(32, aux); + + // 4. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h index 29522f1f95ed..9c3633a499b3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -9,6 +9,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include namespace mlir::LLVM::AMD { // Utility class to take care of buffer operation emission. We may add more @@ -69,6 +70,10 @@ struct BufferEmitter { Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, Value falseVal, CacheModifier cm); + // Emit a predicated rocdl.raw.ptr.buffer.atomic.* RMWOp + Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset, + Value data, Value pred, bool hasUsers); + // Emit a predicated rocdl.raw.ptr.buffer.store void emitStore(Value rsrcDesc, Value offset, Value data, Value pred, CacheModifier cm); @@ -79,10 +84,15 @@ struct BufferEmitter { CacheModifier cm, bool isBufferLoad, SmallVector &args); + // Fill buffer atomics arguments + void fillCommonArgsAtomics(Type type, Value rsrcDesc, Value vOffsetElems, + Value pred, bool hasUsers, + SmallVector &args); + // Given a type, the buffer type can be either the same type // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to // a vector of 4xi32. 
This usually makes the life of the backend easier - Type getBufferOpType(Type type); + Type getBufferOpType(Type type, bool atomicsOp); // Rewriter utilities RewriterBase &rewriter; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 75fcf7bc93cc..2c6afdc9dbe8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -20,6 +20,7 @@ using namespace mlir::triton::gpu; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryBase; +using ::mlir::LLVM::AMD::getVectorSize; using ::mlir::LLVM::AMD::llLoad; using ::mlir::LLVM::AMD::llStore; using ::mlir::triton::gpu::getTotalElemsPerThread; @@ -154,57 +155,6 @@ struct LoadStoreConversionBase { return offsetType.cloneWith(std::nullopt, basePtrType); } - // Get contiguity for a tensor pointer `ptr` - unsigned getContiguity(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - return 1; - return axisAnalysisPass.getPtrContiguity(ptr); - } - - // Get contiguity for a scalar pointer `ptr` and a tensor `offset` - unsigned getContiguity(Value ptr, Value offset) const { - // Get contiguity from the offset - Type type = getPointerTypeWithShape(ptr, offset); - RankedTensorType tensorTy = cast(type); - auto layout = tensorTy.getEncoding(); - auto order = triton::gpu::getOrder(layout); - auto uniqueContigPerThread = - triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); - assert(order[0] < uniqueContigPerThread.size() && - "Unexpected uniqueContigPerThread size"); - unsigned contiguity = uniqueContigPerThread[order[0]]; - - // Get alignment from the pointer. Since this is a scalar pointer - // we should not take the pointer contiguity to consider alignment - auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); - auto maxMultipleBytes = axisInfo->getDivisibility(0); - auto elemNumBits = triton::getPointeeBitWidth(tensorTy); - auto elemNumBytes = std::max(elemNumBits / 8, 1); - auto align = std::max(maxMultipleBytes / elemNumBytes, 1); - - // Final contiguity is a min of the offset contiguity and pointer alignment - contiguity = std::min(align, contiguity); - return contiguity; - } - - // Determine the vector size of a tensor of pointers - unsigned getVectorSize(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - return 1; - auto contiguity = getContiguity(ptr); - auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); - return std::min(128 / pointeeBitWidth, contiguity); - } - - // Given a scalar pointer and a tensor of offsets, determine the vector size - unsigned getVectorSize(Value ptr, Value offset) const { - auto contiguity = getContiguity(ptr, offset); - auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); - return std::min(128 / pointeeBitWidth, contiguity); - } - // Unpack the elements contained in a `llvmStruct` into a `SmallVector` of // `Value`s. 
While you do that, check also the alignment of the mask and // update the vector length `vec` accordingly @@ -288,7 +238,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type valueTy = op.getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); - unsigned vec = getVectorSize(ptr); + unsigned vec = getVectorSize(ptr, axisAnalysisPass); unsigned numElems = getTotalElemsPerThread(ptr.getType()); // Get the LLVM values for pointers @@ -336,7 +286,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, ptrAlignmentBytes, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + rewriter, loc, getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -392,7 +342,7 @@ struct BufferLoadOpConversion typeConverter->convertType(getElementTypeOrSelf(valueTy)); Type ptrType = getPointerTypeWithShape(ptr, offset); unsigned numElems = getTotalElemsPerThread(ptrType); - unsigned vec = getVectorSize(ptr, offset); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); // Get the offset SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); @@ -422,7 +372,7 @@ struct BufferLoadOpConversion vecTy, rsrcDesc, offsetElems[vecStart], pred, falseVal, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + rewriter, loc, getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -470,7 +420,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(valueTy)); // Determine the vectorization size - unsigned vec = getVectorSize(ptr); + unsigned vec = getVectorSize(ptr, axisAnalysisPass); unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); auto ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -514,6 +464,254 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } }; +static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return LLVM::AtomicOrdering::acq_rel; + } +} + +struct BufferAtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferAtomicRMWOp>::ConvertOpToLLVMPattern; + + BufferAtomicRMWOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferAtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + auto atomicRmwAttr = op.getAtomicRmwOp(); + + Value llPtr = 
adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + + // v4f16 and v4bf16 variants of buffer atomics do not exist. + // only v2f16 and v2bf16. + if (valueElemTy.isBF16() || valueElemTy.isF16()) { + // We clamp to the only supported vectorization width here (2). + // In ConvertToBufferOps we check that we have a large enough vector size + assert(vec >= 2); + vec = 2u; + // The max width of a buffer atomic op is 64-bits + // Some types like F32 don't have a 2x vectorized version + } else if (valueElemTy.isF32() || valueElemTy.isF64() || + valueElemTy.isInteger(32) || valueElemTy.isInteger(64)) { + vec = 1u; + } + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // We need to manually emit memory fences (LLVM doesn't do this for buffer + // ops) see: https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 + auto memOrdering = op.getSem(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto rel = LLVM::AtomicOrdering::release; + auto acq = LLVM::AtomicOrdering::acquire; + + bool emitReleaseFence = false; + bool emitAcquireFence = false; + switch (memOrdering) { + case MemSemantic::RELAXED: + // In this case, no memory fences are needed + break; + case MemSemantic::RELEASE: + emitReleaseFence = true; + break; + case MemSemantic::ACQUIRE: + emitAcquireFence = true; + break; + case MemSemantic::ACQUIRE_RELEASE: + emitAcquireFence = true; + emitReleaseFence = true; + default: + // default == acq_rel, so we emit the same barriers + emitAcquireFence = true; + emitReleaseFence = true; + } + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + SmallVector loadedVals; + + // set the scope + auto memScope = op.getScope(); + auto scopeStr = ""; + switch (memScope) { + // System scope is not supported yet + case MemSyncScope::SYSTEM: + return failure(); + case MemSyncScope::GPU: + scopeStr = "agent"; + break; + case MemSyncScope::CTA: + scopeStr = "workgroup"; + break; + default: + return failure(); + } + + StringAttr scope = mlir::StringAttr::get(loc.getContext(), scopeStr); + + if (emitReleaseFence) + rewriter.create(loc, TypeRange{}, rel, scope); + + mlir::Operation *lastRMWOp; + MLIRContext *ctx = rewriter.getContext(); + GCNBuilder waitcntBuilder; + + // Triton supports three scopes for atomic access + // 1. System + // 2. GPU (default) + // 3. CTA (i.e., threadblock or warp-group) + // + // Currently, the AMD backend emits atomics with agent-scope. + // + // The following properties are used to emit proper synchronization + // primitives between sequential buffer atomics See: Memory Model GFX942 + // (MI300 series) + // https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942: + // + // buffer/global/flat_load/store/atomic instructions to global memory are + // termed vector memory operations. + // + // 1. 
Vector memory operations access a single vector L1 cache shared by + // all SIMDs a CU. + // No special action is required for coherence between wavefronts in the + // same work-group since they execute on the same CU. + // + // 2. Each CU has a separate request queue per channel for its associated + // L2. + // Therefore, the vector and scalar memory operations performed by + // wavefronts executing with different L1 caches and the same L2 cache + // can be reordered relative to each other. A `s_waitcnt vmcnt(0)` is + // required to ensure synchronization between vector memory operations of + // different CUs. It ensures a previous vector memory operation has + // completed before executing a subsequent vector memory or LDS operation + // and so can be used to meet the requirements of acquire and release. + // + // 3. Atomic read-modify-write instructions implicitly bypass the L1 cache + // (specific to gfx942) + // Therefore, they do not use the sc0 bit for coherence and instead use + // it to indicate if the instruction returns the original value being + // updated. They do use sc1 to indicate system or agent scope coherence. + // See the cache modifiers word in BufferEmitter::fillCommonArgs for + // more details. + // + // In summary: + // 1. We have to emit memory fences (i.e., acq/rel/acq_rel) before and after + // our buffer atomics. + // 2. Because buffer atomic rmw ops skip the l1 cache, s_waitcnt vmcnt(0) is + // sufficient for synchronization between instructions. + // We don't need to invalidate L1 between these ops on GFX942, just after + // (i.e., we can skip `buffer_wbinvl1_vol`) + // 3. We don't have to explicitly write to the l2 cache because + // `s_waitcnt vmcnt(0)` already does this as-per the MI300/CDNA3 ISA + // docs: "Decremented for reads when the data has been written back to + // the VGPRs, and for writes when the data has been written to the L2 + // cache. Ordering: Memory reads and writes return in the order they were + // issued, including mixing reads and writes" + // 4. We set GLC=1, to return the old value. Atomics in GFX942 execute with + // either device (default) or system scope (controlled by the sc1 flag). + // This is distinct from the memory scope of the atomic (i.e, the memory + // fences which appear before/after the ops). + + if (memScope == MemSyncScope::GPU) { + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } else if (memScope == MemSyncScope::CTA) { + // TODO: Within a CTA we can possibly relax this? + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } + + // Check if the op has users, if it does we set GLC=1, otherwise GLC=0 + auto opUsers = op.getResult().getUsers(); + auto hasUsers = std::distance(opUsers.begin(), opUsers.end()) > 0; + + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? 
and_(maskElems[vecStart], rDataMask) : rDataMask; + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + + Value loadVal = bufferEmitter.emitAtomicRMW( + atomicRmwAttr, vecTy, rsrcDesc, offsetElems[vecStart], storeVal, pred, + hasUsers); + // Track the last op, so we can emit a fenceop after the loop + lastRMWOp = loadVal.getDefiningOp(); + + // To sync vector memory ops between CUs within an agent, we need an + // s_waitcnt skip doing this on the last iteration of the loop + // In the relaxed memory ordering, we don't need this barrier + if (vecStart < numElems - vec && (emitReleaseFence || emitAcquireFence)) { + Value inst = + waitcntBuilder.launch(rewriter, lastRMWOp->getLoc(), void_ty(ctx)); + lastRMWOp = inst.getDefiningOp(); + } + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + // Acquire Fence post-atomic + if (emitAcquireFence) + rewriter.create(lastRMWOp->getLoc(), TypeRange{}, acq, + scope); + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + struct BufferStoreOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -553,7 +751,7 @@ struct BufferStoreOpConversion Type ptrType = getPointerTypeWithShape(ptr, offset); unsigned numElems = getTotalElemsPerThread(ptrType); - unsigned vec = getVectorSize(ptr, offset); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); // Get the offsets and value SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); @@ -581,21 +779,6 @@ struct BufferStoreOpConversion } }; -static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { - switch (memOrdering) { - case MemSemantic::RELAXED: - return LLVM::AtomicOrdering::monotonic; - case MemSemantic::ACQUIRE: - return LLVM::AtomicOrdering::acquire; - case MemSemantic::RELEASE: - return LLVM::AtomicOrdering::release; - case MemSemantic::ACQUIRE_RELEASE: - return LLVM::AtomicOrdering::acq_rel; - default: - return LLVM::AtomicOrdering::acq_rel; - } -} - struct AtomicCASOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -641,7 +824,7 @@ struct AtomicCASOpConversion auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar - auto vec = getVectorSize(op.getPtr()); + auto vec = getVectorSize(op.getPtr(), axisAnalysisPass); // tensor if (TensorTy) { auto valTy = cast(op.getVal().getType()); @@ -839,7 +1022,7 @@ struct AtomicRMWOpConversion const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); + auto vec = getVectorSize(ptr, axisAnalysisPass); int numElems = 1; Type packF16Ty = vec_ty(valueElemTy, 2); @@ -940,13 +1123,13 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToEnd(atomicBlock); auto maybeKind = matchAtomicOp(atomicRmwAttr); - // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient - // 
atomics for MI-* series of AMD GPU. + Value atom = rewriter .create(loc, *maybeKind, rmwPtr, operand, atomicMemOrdering, StringRef(scopeStr.value())) .getResult(); + if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = @@ -1009,9 +1192,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns - .add( - typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 0766b611ee88..ad7d892b3270 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -6,6 +6,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::ModuleAxisInfoAnalysis; using mlir::triton::AMD::DppCtrl; using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; @@ -438,19 +439,22 @@ getCacheModifierFlagsForPredicatedCall(LLVM::CallOp callOp) { // - SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system // - NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse // -// -----+-----+-----+-----+----+-- -// Op | cm | SC1 | SC0 | NT | -// -----+-----+-----+-----+----+-- -// Load | .ca | 0 | 0 | 0 | -// | .cg | 0 | 1 | 1 | -// | .cs | 0 | 1 | 1 | -// | .cv | 1 | 1 | x | -// -----+-----+-----+-----+----+-- -// Store| .wb | 0 | 0 | 0 | -// | .cg | 0 | 0 | 0 | -// | .cs | 0 | 1 | 1 | -// | .wt | 1 | x | x | -// -----+-----+-----+-----+----+-- +// -------+-----+-----+-----+----+-- +// Op | cm | SC1 | SC0 | NT | +// -------+-----+-----+-----+----+-- +// Load | .ca | 0 | 0 | 0 | +// | .cg | 0 | 1 | 1 | +// | .cs | 0 | 1 | 1 | +// | .cv | 1 | 1 | x | +// -------+-----+-----+-----+----+-- +// Store | .wb | 0 | 0 | 0 | +// | .cg | 0 | 0 | 0 | +// | .cs | 0 | 1 | 1 | +// | .wt | 1 | x | x | +// -------+-----+-----+-----+----+-- +// Atomic | N/A | 0 | 1 | x | Setting sc0 returns the pre-op value +// | N/A | 1 | 0 | x | Setting sc1 performs a system-scope atomic +// -------+-----+-----+-----+----+-- static int32_t getCtrlBitsForCacheModifierOnGFX942(triton::CacheModifier cm, bool isBufferLoad) { const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b1000; @@ -481,6 +485,19 @@ static int32_t getCtrlBitsForCacheModifierOnGFX942(triton::CacheModifier cm, return aux; } +int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1, + bool setNT) { + const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b1000; + int32_t aux = 0; + if (setSC0) + aux |= sc0Bit; + if (setSC1) + aux |= sc1Bit; + if (setNT) + aux |= ntBit; + return aux; +} + static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) { return 0; } @@ -520,4 +537,59 @@ Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, return builder.launch(rewriter, loc, f16_ty, false); } +Type getPointerTypeWithShape(Value basePtr, Value offset) { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); +} + +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getPtrContiguity(ptr); +} + +unsigned getContiguity(Value 
ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + // Get contiguity from the offset + Type type = getPointerTypeWithShape(ptr, offset); + RankedTensorType tensorTy = cast(type); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + + // Get alignment from the pointer. Since this is a scalar pointer + // we should not take the pointer contiguity to consider alignment + auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); + auto maxMultipleBytes = axisInfo->getDivisibility(0); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto align = std::max(maxMultipleBytes / elemNumBytes, 1); + + // Final contiguity is a min of the offset contiguity and pointer alignment + contiguity = std::min(align, contiguity); + return contiguity; +} + +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + return std::min(128 / pointeeBitWidth, contiguity); +} + +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto contiguity = getContiguity(ptr, offset, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); + return std::min(128 / pointeeBitWidth, contiguity); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index a79913262a49..1dabe31db2d9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -7,6 +7,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -56,8 +57,31 @@ std::pair getCacheModifierFlagsForPredicatedCall(LLVM::CallOp); int32_t getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool, mlir::triton::AMD::TargetInfo &); +// Get cache modifier information for buffer atomics +int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1, + bool setNT); + Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, triton::RoundingMode rounding); + +// Return a tensor of pointers with the same type of `basePtr` and the same +// shape of `offset` +Type getPointerTypeWithShape(Value basePtr, Value offset); + +// Get contiguity for a tensor pointer `ptr` +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Get contiguity for a scalar pointer `ptr` and a tensor `offset` +unsigned getContiguity(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Determine the vector size of a tensor of pointers +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Given a scalar pointer and a tensor of offsets, determine the vector size +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + } // namespace mlir::LLVM::AMD #endif // 
TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index b9e2aaf3668e..9378596e360d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -12,6 +12,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -23,11 +25,15 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h" +#undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; +using ::mlir::LLVM::AMD::getVectorSize; +using mlir::triton::AMD::ISAFamily; + namespace ttg = mlir::triton::gpu; namespace tt = mlir::triton; @@ -227,6 +233,129 @@ bool canUseBufferOps(Value ptr, const DenseSet &assumptions) { } } // namespace +struct ConvertTritonAtomicRMWOpToBufferAtomicRMW + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonAtomicRMWOpToBufferAtomicRMW( + mlir::MLIRContext *context, DenseSet &assumptions, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : mlir::OpRewritePattern(context), + assumptions(assumptions), axisAnalysisPass(axisAnalysisPass) {} + + mlir::LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + auto atomicRmwOp = op.getAtomicRmwOp(); + auto sem = op.getSem(); + auto scope = op.getScope(); + + // In addition to the `canUserBufferOps` check, we should ensure that + // 1. Perform the canUserBufferOps check + if (!canUseBufferOps(ptr, assumptions)) { + return rewriter.notifyMatchFailure(op, "canUseBufferOps check failed"); + } + + // 2. Check the scope. We support GPU and CTA for now (SYSTEM scope is not + // supported yet) + switch (scope) { + case MemSyncScope::GPU: + case MemSyncScope::CTA: + break; + default: + return rewriter.notifyMatchFailure(op, "RMW with unsupported scope"); + } + LDBG("RMW supported scope"); + + // 3. Check the memory ordering. + // TODO: support monotonic + switch (sem) { + case MemSemantic::RELAXED: + case MemSemantic::RELEASE: + case MemSemantic::ACQUIRE: + case MemSemantic::ACQUIRE_RELEASE: + break; + default: + return rewriter.notifyMatchFailure( + op, "RMW with unsupported memory ordering"); + } + + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + + // 4. Buffer atomic RMW does not support FP8 ops + // easier to just check what we support + auto checkType = getElementTypeOrSelf(op.getVal()); + bool isSupportedType = checkType.isF16() || checkType.isBF16() || + checkType.isF32() || checkType.isF64() || + checkType.isInteger(32) || checkType.isInteger(64); + if (!isSupportedType) { + return rewriter.notifyMatchFailure(op, "RMW with unsupported type"); + } + LDBG("RMW supported type"); + + // 5. 
Check if the RMWOp is supported + switch (atomicRmwOp) { + case RMWOp::AND: + case RMWOp::OR: + case RMWOp::XOR: + case RMWOp::ADD: + case RMWOp::FADD: + case RMWOp::MAX: + case RMWOp::MIN: + case RMWOp::UMAX: + case RMWOp::UMIN: + case RMWOp::XCHG: + break; + default: + auto rmwOpStr = stringifyRMWOp(atomicRmwOp).str(); + return rewriter.notifyMatchFailure(op, "RMW with unsupported op: " + + rmwOpStr); + } + LDBG("RMW supported Op"); + + // 6. Buffer atomics support 32 and 64-bit operations, so inputs must be at + // least 32-bits. Otherwise, fall back to the existing path for atomics + auto opValueType = op.getVal().getType(); + auto opBitWidth = 0; + if (auto tensorType = dyn_cast(opValueType)) { + // We can't just compute the opBitWidth using the numElements * + // elemBitWidth here. In cases such as tensor<2xf16...>, if the elements + // are contiguous we can emit the buffer op. Otherwise, the buffer ops + // lowering will try to emit individual (unsupported) f16/bf16 ops. + auto elemBitWidth = tensorType.getElementTypeBitWidth(); + opBitWidth = + getVectorSize(basePtr, tensorOffset, axisAnalysisPass) * elemBitWidth; + } else { + opBitWidth = opValueType.getIntOrFloatBitWidth(); + } + + if (opBitWidth < 32) { + return rewriter.notifyMatchFailure(op, "RMW requires opBitWidth >= 32"); + } + + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + + rewriter.replaceOpWithNewOp( + op, op.getVal().getType(), atomicRmwOp, basePtr, tensorOffset, + op.getVal(), sem, scope, maybeMask); + + return success(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + struct ConvertTritonLoadToBufferLoad : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -267,7 +396,6 @@ struct ConvertTritonLoadToBufferLoad opIdxAttr); } rewriter.replaceOp(op, bufferLoadOp); - return success(); } LDBG("Failed to convert: " << op); @@ -322,13 +450,17 @@ class TritonAMDGPUConvertToBufferOpsPass public: TritonAMDGPUConvertToBufferOpsPass() = default; + TritonAMDGPUConvertToBufferOpsPass(StringRef archGen) { + this->archGenerationName = archGen.data(); + }; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - ModuleOp m = getOperation(); + ModuleOp mod = getOperation(); + // Collect assumptions in the function DenseSet assumptions; - m.walk([&](LLVM::AssumeOp op) { + mod.walk([&](LLVM::AssumeOp op) { if (op->getOperand(0).getDefiningOp()) assumptions.insert(op->getOperand(0)); }); @@ -337,13 +469,23 @@ class TritonAMDGPUConvertToBufferOpsPass LDBG("Assumption:" << assume); } + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); patterns.add(context, assumptions); patterns.add(context, assumptions); - if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + + // Gate buffer atomics behind CDNA3 (i.e., MI300 series) for now + // GFX942-specific assumptions regarding cache coherence are made when + // lowering to LLVM + if (ISAFamily::CDNA3 == triton::AMD::deduceISAFamily(archGenerationName)) + patterns.add( + context, assumptions, axisInfoAnalysis); + + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); } }; -std::unique_ptr mlir::createTritonAMDGPUConvertToBufferOpsPass() { - return std::make_unique(); +std::unique_ptr +mlir::createTritonAMDGPUConvertToBufferOpsPass(std::string archGen) { + return std::make_unique(archGen); } diff --git 
a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index 1e999d631b3f..9e27e06b97e9 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -68,8 +68,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                      mlir::createTritonAMDGPUOptimizeEpiloguePass);
   ADD_PASS_WRAPPER_0("add_canonicalize_pointers",
                      mlir::createTritonAMDGPUCanonicalizePointersPass);
-  ADD_PASS_WRAPPER_0("add_convert_to_buffer_ops",
-                     mlir::createTritonAMDGPUConvertToBufferOpsPass);
+  ADD_PASS_WRAPPER_1("add_convert_to_buffer_ops",
+                     mlir::createTritonAMDGPUConvertToBufferOpsPass,
+                     const std::string &);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      mlir::createTritonAMDGPUReorderInstructionsPass);
   ADD_PASS_WRAPPER_0("add_block_pingpong",
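
For orientation, a user-level sketch (illustrative only, not part of the patch): a kernel like the one below produces the `tt.atomic_rmw fadd, acq_rel, gpu` seen in the tests, and with `AMDGCN_USE_BUFFER_OPS=1` on gfx942 it is a candidate for the new `amdgpu.buffer_atomic_rmw` lowering, provided the offset analysis in ConvertToBufferOps can prove the accesses in-bounds. The kernel and argument names are made up; `tl.atomic_add`'s `sem`/`scope` keywords are assumed to map onto the `acq_rel`/`gpu` attributes used in the IR.

```python
# Illustrative Triton kernel whose masked atomic add targets the new
# buffer-atomic path on CDNA3 when buffer ops are enabled.
import triton
import triton.language as tl

@triton.jit
def atomic_fadd_kernel(out_ptr, x_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)   # becomes the tensor of offsets
    mask = offs < N                            # becomes the buffer-op mask
    x = tl.load(x_ptr + offs, mask=mask)
    # Lowers to tt.atomic_rmw fadd with acq_rel semantics at GPU (agent) scope.
    tl.atomic_add(out_ptr + offs, x, mask=mask, sem="acq_rel", scope="gpu")
```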
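
The GFX942 cache-control word built by `getCtrlBitsForBufferAtomicsOnGFX942` is small enough to check by hand. The following self-contained sketch (bit positions taken from the patch; the function and variable names here are illustrative) reproduces the encoding and the two `aux` values the lowering actually emits: sc0 is set only when the atomic's result has users (GLC returns the pre-op value), sc1 is reserved for system-scope coherence, and nt stays clear.

```python
# Bit layout from the patch: sc0 = 0b1, nt = 0b10, sc1 = 0b1000.
SC0, NT, SC1 = 0b1, 0b10, 0b1000

def buffer_atomic_aux(has_users: bool, system_scope: bool = False,
                      nontemporal: bool = False) -> int:
    aux = 0
    if has_users:        # GLC: return the pre-op value
        aux |= SC0
    if system_scope:     # system-scope coherence (not used by this lowering yet)
        aux |= SC1
    if nontemporal:      # no temporal-reuse hint is emitted for atomics
        aux |= NT
    return aux

assert buffer_atomic_aux(has_users=True) == 0b0001   # returning atomic, agent scope
assert buffer_atomic_aux(has_users=False) == 0b0000  # fire-and-forget atomic
```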