[AMD] Add support for buffer atomic RMW (#5549)

# Overview

This PR enables the raw.ptr.buffer.atomic.* RMW ops in the AMD backend.
They share calling conventions and semantics with the other buffer ops in
the backend.

The new ops are gated behind the `AMDGCN_ENABLE_BUFFER_ATOMICS`
environment variable which must be used in conjunction with
`AMDGCN_USE_BUFFER_OPS`. They are also gated behind the GPU being CDNA3
(MI300-series GPUs) for now as the optimizations I added make
assumptions regarding GFX942.
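
For illustration, a minimal sketch of enabling the feature from a test script might look like the following (the script layout is illustrative; both variables appear to be read by the AMD backend when the kernel is compiled, so they should be set beforehand):

```python
# Illustrative only: set both variables before any kernel is compiled.
import os

os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"         # enables buffer ops in general
os.environ["AMDGCN_ENABLE_BUFFER_ATOMICS"] = "1"  # additionally enables buffer atomic RMW

import triton  # import and compile kernels only after configuring the environment
```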

I originally started exploratory work on this PR to better understand the
comment in `LoadStoreOpToLLVM.cpp` that refers to buffer atomics as "more
efficient". In short, I found that on their own they aren't necessarily
more efficient, but combining them with more careful control over how
cache-coherence ops and memory fences are emitted can improve performance
by a significant fraction.

# How

I've added a new buffer atomic RMW op in the AMDGPUOps dialect which has
its own lowering in the backend. There are a number of checks in place
to ensure that the lowering is done correctly between the
ConvertToBufferOps pass and the LoadStoreOpToLLVM lowering.

The actual lowering is where most of the performance gains come from. At
a high level, when non-buffer atomic RMW ops are emitted, the memory
fences lower to something along the lines of:
```
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0) 
buffer_inv sc1
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0) 
buffer_inv sc1
```


If my understanding of the [GFX942 memory
model](https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942) is
correct, then given several assumptions regarding CDNA3, this can
actually be lowered to something that resembles:
```
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0) # AMDGCN specific cross-CU synchronization primitive
atomicRMWop()
s_waitcnt vmcnt(0) 
buffer_inv sc1
```

There are comments in the code that explain the thought process for why
(I think) this is okay.

It appears that AMD's CK library (AMD's counterpart to CUTLASS) uses similar
synchronization mechanisms, although I am probably missing some of the
context here
(https://github.com/ROCm/composable_kernel/blob/9e95d54cd2160dffc07c1197951a9ab1ca6c35f2/include/ck_tile/core/arch/amd_buffer_addressing.hpp#L619).

# Results and Testing

In addition to the added lit test, I ran the existing in-tree atomic RMW
tests with buffer ops + buffer atomics enabled, and they appear to pass.

Following this, I evaluated FP16 Split-K
[gemm](https://github.com/pytorch-labs/tritonbench/blob/a2f668e38ec55978bfcf2a6a8d15294a5b9d3d36/tritonbench/operators/gemm/kernels/matmul.py#L190)
with [llama
shapes](https://github.com/pytorch-labs/tritonbench/blob/a2f668e38ec55978bfcf2a6a8d15294a5b9d3d36/tritonbench/utils/triton_op.py#L149)
in tritonbench on an MI300X. Some minor modifications were made to the
kernel so that it emits buffer ops (e.g., adding tl.assume calls). For
testing purposes, I disabled the non-split-k configurations. I also checked
the numerical accuracy with rtol=atol=1e-4 for all shapes here.
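
For reference, the flavor of those modifications is sketched below. This is an illustrative kernel rather than the actual tritonbench split-K GEMM, and the names are hypothetical; the `tl.assume` hints are the kind that let the ConvertToBufferOps pass prove the offsets are non-negative so the atomic can be rewritten into a buffer atomic.

```python
# Illustrative sketch, not the benchmarked kernel: a split-K style accumulation
# where each program atomically adds its partial results into the output.
import triton
import triton.language as tl

@triton.jit
def splitk_accumulate(c_ptr, partial_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Hint for the compiler: with provably non-negative offsets the
    # tritonamdgpu-convert-buffer-ops pass should be able to turn the atomic
    # below into an amdgpu.buffer_atomic_rmw (given the env vars above and a
    # CDNA3 GPU).
    tl.assume(pid >= 0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    vals = tl.load(partial_ptr + offs, mask=mask)
    tl.atomic_add(c_ptr + offs, vals, mask=mask)
```

The accuracy check mentioned above then amounts to something like `torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)` against a non-split-k reference.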

<img width="768" alt="Average TFLOPS per M-dim bucket, buffer atomics vs. baseline"
src="https://github.com/user-attachments/assets/83b40b22-675a-410f-a44d-a138d2387935"
/>

Each bucket in the figure above corresponds to the average TFLOPS of all
shapes that share the same `M` dimension.
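
For clarity, the per-bucket aggregation is roughly the following (the dataframe layout and record list are hypothetical; only the group-by-`M` averaging is taken from the description above):

```python
# Hypothetical sketch of the aggregation behind the figure: mean TFLOPS over
# all llama shapes that share the same M dimension.
import pandas as pd

# One (M, N, K, tflops) record per benchmarked shape would go here.
rows = []
df = pd.DataFrame(rows, columns=["M", "N", "K", "tflops"])
per_m = df.groupby("M")["tflops"].mean()
print(per_m)
```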

At smaller batch sizes the performance is roughly equivalent. At BS=32,
buffer atomics have ~50% greater TFLOPS. At BS=256, buffer atomics have
~3.75x the TFLOPS.

Note: the purpose of this test is to evaluate the performance of buffer
atomics; split-k is not always the optimal strategy for these shapes/workloads.
SamGinzburg authored Jan 17, 2025
1 parent 94f80f4 commit 6556ec6
Showing 16 changed files with 735 additions and 116 deletions.
37 changes: 37 additions & 0 deletions test/Conversion/amd/buffer_load_store.mlir
@@ -178,3 +178,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.return
}
}

// -----

#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
// CHECK-LABEL: buffer_atomic
tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>) {
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c128_i32 : i32
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<128xi32, #blocked0>
%5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0>
%mask = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0>
// CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)>
// There should be a single release fence before any atomics
// CHECK: llvm.fence syncscope("agent") release
// CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}}
// CHECK: %[[offset:.*]] = llvm.select %[[mask1]]

// We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call.
%ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask : tensor<128xf32, #blocked0>

// CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
// CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
// CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
// CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32

// There should be a single acquire fence after all of the atomics
// CHECK: llvm.fence syncscope("agent") acquire
tt.return
}
}
30 changes: 29 additions & 1 deletion test/TritonGPU/amd/amd-convert-buffer-ops.mlir
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s
// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx940'| FileCheck %s

#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
@@ -482,3 +482,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: assume_positive_offset_buffer_atomic
tt.func @assume_positive_offset_buffer_atomic(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%c128_i32 = arith.constant 128 : i32
%c0_i32 = arith.constant 0 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%sub = arith.subi %1, %c128_i32 : i32
%cmp = arith.cmpi sgt, %sub, %c0_i32 : i32
llvm.intr.assume %cmp : i1
%2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked>
%3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// CHECK: %[[offset:.*]] = arith.addi
%4 = arith.addi %2, %3 : tensor<1024xi32, #blocked>
// CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
%6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
%8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
tt.return %8 : tensor<1024xf32, #blocked>
}
}
3 changes: 2 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -220,6 +220,7 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_optimize_dot_operands(pm, True)

stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"

# The `local-prefetch` scheduling variant requires turning on buffer ops.
if options.instruction_sched_variant == "local-prefetch":
@@ -247,7 +248,7 @@
if use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
1 change: 1 addition & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
@@ -30,6 +30,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Traits.h"

// clang-format off
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc"
@@ -25,6 +25,7 @@
#define TRITON_AMDGPU_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "TritonAMDGPUDialect.td"
include "mlir/IR/EnumAttr.td"

@@ -34,7 +34,7 @@ def TritonAMDGPU_Dialect : Dialect {
TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level.
}];

let dependentDialects = [];
let dependentDialects = ["triton::TritonDialect"];

let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
49 changes: 47 additions & 2 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -26,13 +26,16 @@
#define TRITON_AMDGPU_OPS

include "mlir/IR/OpBase.td"
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"

include "mlir/IR/EnumAttr.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "TritonAMDGPUDialect.td"
include "TritonAMDGPUAttrDefs.td"

@@ -209,6 +212,48 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
}];
}

def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
SameLoadStoreOperandsAndResultEncoding,
MemoryEffects<[MemRead<GlobalMemory>]>,
MemoryEffects<[MemWrite<GlobalMemory>]>,
TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">,
TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
]>{
let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
let description = [{
AMD Buffer atomic RMW operation. Buffer atomics are similar to normal atomics, but access global memory via a
scalar base pointer and a tensor of offsets instead of a tensor of pointers.
Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with
the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed).
Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with
the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped.
}];
let arguments = (
ins
TT_AtomicRMWAttr:$atomic_rmw_op,
TT_Ptr:$ptr,
I32Tensor:$offsets,
TT_Tensor:$value,
TT_MemSemanticAttr:$sem,
TT_MemSyncScopeAttr:$scope,
Optional<TT_BoolTensor>:$mask
);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
attr-dict `:` type($result)
}];
}


def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
SameLoadStoreOperandsEncoding,
MemoryEffects<[MemWrite<GlobalMemory>]>,
3 changes: 2 additions & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -26,7 +26,8 @@ std::unique_ptr<Pass> createTritonAMDGPUOptimizeEpiloguePass();

std::unique_ptr<Pass> createTritonAMDGPUCanonicalizePointersPass();

std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass();
std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass(
std::string archGenName = std::string());

std::unique_ptr<Pass> createTritonAMDGPUBlockPingpongPass();

8 changes: 7 additions & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -117,11 +117,17 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "
def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> {
let summary = "Convert memory operations to buffer operations";

let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible";
let description = "This pass converts memory and atomic operations (e.g., tt.load/tt.store/tt.atomic_rmw) to amdgpu buffer operations, if possible";

let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()";

let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"archGenerationName", "arch-generation-name",
"std::string", /*default=*/"std::string{}",
"GFX generation name of target device.">,
];
}

def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::ModuleOp"> {
81 changes: 74 additions & 7 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
@@ -77,7 +77,7 @@ Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset,
triton::CacheModifier cm) {
SmallVector<Value, 6> args;
fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, args);
Type bufferType = getBufferOpType(type);
Type bufferType = getBufferOpType(type, false);
Value data = rewriter.create<ROCDL::RawPtrBufferLoadOp>(
loc, bufferType, args, ArrayRef<NamedAttribute>());
data = bitcast(data, type);
@@ -86,10 +86,34 @@
return data;
}

Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc,
Value offset, Value data, Value pred,
bool hasUsers) {
VectorType vecTy = cast<VectorType>(data.getType());
Type bufferType = getBufferOpType(type, true);
if (vecTy != bufferType)
data = bitcast(data, bufferType);

SmallVector<Value, 6> args{data};
fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args);

// TODO:
// The ops in ROCDL (e.g., RawPtrBufferAtomicFaddOp) have no return value,
// but they lower to intrinsics that can return values. This causes the
// LLVM verifier to fail. When this is fixed, the ROCDL ops should be used
// here.
auto rmwOpStr = stringifyRMWOp(rmwType).str();
auto intrinsic = "llvm.amdgcn.raw.ptr.buffer.atomic." + rmwOpStr;
auto bufferAtomicRMW = LLVM::createLLVMIntrinsicCallOp(
rewriter, loc, intrinsic, bufferType, args);

return bitcast(bufferAtomicRMW.getResult(0), type);
}

void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data,
Value pred, triton::CacheModifier cm) {
VectorType vecTy = cast<VectorType>(data.getType());
Type bufferType = getBufferOpType(vecTy);
Type bufferType = getBufferOpType(vecTy, false);
if (vecTy != bufferType)
data = bitcast(data, bufferType);
SmallVector<Value, 6> args{data};
Expand All @@ -99,7 +123,7 @@ void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data,
ArrayRef<NamedAttribute>());
}

Type BufferEmitter::getBufferOpType(Type type) {
Type BufferEmitter::getBufferOpType(Type type, bool atomicsOp) {
int64_t vecSize = 1;
Type elementType = type;
if (auto vecType = dyn_cast<VectorType>(type)) {
Expand All @@ -110,16 +134,20 @@ Type BufferEmitter::getBufferOpType(Type type) {
const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth());
const size_t totalWidthBits = valueElemNBits * vecSize;

// For bf16, always convert to i16
Type bufferElementType = elementType;
if (elementType.isBF16())
// We don't want to cast from bf16 if we are emitting buffer atomics
if (elementType.isBF16() && !atomicsOp) {
bufferElementType = rewriter.getI16Type();
}

// If we are dealing with a subword type (e.g., i8 or f16) but we
// still need multiple words, then pack the subwords into 32bit integers
// and update the vector length and the type
// We never need to pack for buffer atomics because we ensure that
// 1) we can always emit a 32-bit / 64-bit atomic op, and
// 2) for tensors of 16-bit values, the values are contiguous
int64_t bufferVecSize = vecSize;
if (valueElemNBits < 32) {
if (valueElemNBits < 32 && !atomicsOp) {
if (totalWidthBits > 32) {
bufferElementType = rewriter.getI32Type();
bufferVecSize = totalWidthBits / 32;
@@ -166,10 +194,49 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc,
getCtrlBitsForCacheModifierOnTarget(cm, isBufferLoad, targetInfo);
Value cacheModifiers = int_val(32, aux);

// 5. Add the arguments
// 4. Add the arguments
args.push_back(rsrcDesc);
args.push_back(maskedOffsetBytes);
args.push_back(sgprOffset);
args.push_back(cacheModifiers);
}

void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc,
Value vOffsetElems, Value pred,
bool hasUsers,
SmallVector<Value> &args) {

// 1. Create the (masked) offset
Type elementType = getElementTypeOrSelf(type);
const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth());
const int elementByteWidth = valueElemNBits / 8;
// Please note: the index passed is not in bytes, but in number of elements.
// In order to pass the index to the buffer operation, we need to convert it
// to bytes (i.e., we need to multiply by `elementByteWidth`).
Value vOffsetOutOfBounds = int_val(
32, static_cast<int>(std::numeric_limits<int>::max() + int64_t(1)));
Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems);
Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBounds);

// 2. Set the sgprOffset to 0
Value sgprOffset = int_val(32, 0);

// 3. Create the cache modifiers word
int32_t aux = 0;
if (hasUsers)
aux = getCtrlBitsForBufferAtomicsOnGFX942(/*setSC0*/ true, /*setSC1*/ false,
/*setNT*/ false);
else
aux = getCtrlBitsForBufferAtomicsOnGFX942(
/*setSC0*/ false, /*setSC1*/ false, /*setNT*/ false);

Value cacheModifiers = int_val(32, aux);

// 4. Add the arguments
args.push_back(rsrcDesc);
args.push_back(maskedOffsetBytes);
args.push_back(sgprOffset);
args.push_back(cacheModifiers);
}

} // namespace mlir::LLVM::AMD
12 changes: 11 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h
@@ -9,6 +9,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include <cstdint>

namespace mlir::LLVM::AMD {
// Utility class to take care of buffer operation emission. We may add more
@@ -69,6 +70,10 @@ struct BufferEmitter {
Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred,
Value falseVal, CacheModifier cm);

// Emit a predicated rocdl.raw.ptr.buffer.atomic.* RMWOp
Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset,
Value data, Value pred, bool hasUsers);

// Emit a predicated rocdl.raw.ptr.buffer.store
void emitStore(Value rsrcDesc, Value offset, Value data, Value pred,
CacheModifier cm);
@@ -79,10 +84,15 @@ struct BufferEmitter {
CacheModifier cm, bool isBufferLoad,
SmallVector<Value> &args);

// Fill buffer atomics arguments
void fillCommonArgsAtomics(Type type, Value rsrcDesc, Value vOffsetElems,
Value pred, bool hasUsers,
SmallVector<Value> &args);

// Given a type, the buffer type can be either the same type
// or a packed version. E.g., a vector of 8xfp16 can be bitcasted to
// a vector of 4xi32. This usually makes the life of the backend easier
Type getBufferOpType(Type type);
Type getBufferOpType(Type type, bool atomicsOp);

// Rewriter utilities
RewriterBase &rewriter;