Skip to content

Commit

Permalink
[Codegen] Always use ? for non-zero offsets (#19952)
Browse files Browse the repository at this point in the history
In order to rewrite subspans to buffer descriptors, we might need to be
able to fold offsets into the buffer descriptors. This means that we
need to be able to replace an offset with a different one (specifically
0) because the offset will be applied to the base pointer during buffer
casts. If the offset is dynamic, we can always memref.cast the
dynamic-ness of the offset back in, but we can't replace a static offset
with a different static offset. Therefore, never create buffers that
have a static non-zero offset during bufferization.
  • Loading branch information
krzysz00 authored Feb 11, 2025
1 parent d81bb13 commit 4bc495b
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ struct FlattenBindingSubspan final
if (byteOffset && !matchPattern(byteOffset, m_Zero())) {
elementOffset = convertByteOffsetToElementOffset(
rewriter, loc, byteOffset, oldType.getElementType());
// The element offset needs to look dynamic.
elementOffset =
getValueOrCreateConstantIndexOp(rewriter, loc, elementOffset);
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
linearShape = affine::makeComposedFoldedAffineApply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func.func @elementwise() {
// CHECK: func.func @elementwise()
// CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
// CHECK-DAG: %[[CST_BUF:.+]] = bufferization.to_memref %[[CST_TENSOR]]
// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: 16>, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) {{.+}} : memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// CHECK: scf.for
// CHECK-DAG: %[[SUB_IN1:.+]] = memref.subview %[[IN_BUF]]
// CHECK-DAG: %[[SUB_OUT1:.+]] = memref.subview %[[OUT_BUF]]
Expand Down Expand Up @@ -2589,8 +2589,8 @@ func.func @reduction_ew() {
}

// CHECK: func.func @reduction_ew
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1001xf32, strided<[1], offset: 1280>, #hal.descriptor_type<storage_buffer>>
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1x1001xf32, strided<[1001, 1], offset: 1280>, #hal.descriptor_type<storage_buffer>>
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1001xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c5120) : memref<1x1001xf32, strided<[1001, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) : memref<1x1001xf32, #hal.descriptor_type<storage_buffer>>

// -----
Expand Down Expand Up @@ -2714,7 +2714,7 @@ func.func @sub_byte_bufferize_with_offset() {
// CHECK-LABEL: func.func @sub_byte_bufferize_with_offset()
// CHECK: %[[C64:.+]] = arith.constant 64 : index
// CHECK: hal.interface.binding.subspan layout({{.+}}) binding(0)
// CHECK-SAME: memref<64xi4, strided<[1], offset: 128>
// CHECK-SAME: memref<64xi4, strided<[1], offset: ?>

// -----

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,13 @@ findOrCreateSubspanBuffer(RewriterBase &rewriter,
Value byteOffset = subspanOp.getByteOffset();
MemRefLayoutAttrInterface layoutAttr = {};
if (byteOffset && !matchPattern(byteOffset, m_Zero())) {
OpFoldResult elementOffset = convertByteOffsetToElementOffset(
rewriter, subspanOp->getLoc(), subspanOp.getByteOffset(),
shapedType.getBoundElementType());
std::optional<int64_t> elementOffsetInt =
getConstantIntValue(elementOffset);
if (!elementOffsetInt) {
elementOffsetInt = ShapedType::kDynamic;
}
// Using buffer resources on AMDGPU will require buffers to be relocated to
// offset 0, so any static offset we can compute here might change.
// Therefore, always use a ? for the offset field unless it's known to be 0.
auto tensorType = llvm::cast<RankedTensorType>(shapedType.getBoundType());
SmallVector<int64_t> strides = getStridesFromShape(tensorType.getShape());
layoutAttr = StridedLayoutAttr::get(rewriter.getContext(),
elementOffsetInt.value(), strides);
ShapedType::kDynamic, strides);
}
auto memRefType =
getMemrefTypeForTensor(shapedType, layoutAttr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ func.func @interleave_and_bitcast_lowering() {
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c4096) flags(ReadOnly) : memref<128xi8, strided<[1], offset: 4096>>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c4096) flags(ReadOnly) : memref<128xi8, strided<[1], offset: ?>>
%out_buffer = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c8192) : memref<256x64xi4, strided<[64, 1], offset: 8192>>
%2 = vector.load %0[%c0] : memref<128xi8, strided<[1], offset: 4096>>, vector<2xi8>
%2 = vector.load %0[%c0] : memref<128xi8, strided<[1], offset: ?>>, vector<2xi8>
%3 = vector.bitcast %2 : vector<2xi8> to vector<4xi4>
%4 = vector.insert %3, %cst_0 [3] : vector<4xi4> into vector<4x4xi4>
%5 = vector.bitcast %4 : vector<4x4xi4> to vector<4x2xi8>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ hal.executable @abs_ex_dispatch_0 {
func.func @abs_ex_dispatch_0() {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<16xi32>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32>
%3 = gpu.block_id x
%4 = gpu.block_dim x
%5 = gpu.thread_id x
%6 = arith.muli %3, %4 : index
%7 = arith.addi %6, %5 : index
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 32>>
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
%10 = memref.load %1[%7] : memref<16xi32>
%11 = arith.sitofp %10 : i32 to f32
%12 = arith.addf %9, %11 : f32
Expand Down Expand Up @@ -145,15 +145,15 @@ hal.executable @mixed_type {
func.func @mixed_type() {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) : memref<16xf32, strided<[1], offset: 4>>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) : memref<16xf32, strided<[1], offset: ?>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c0) : memref<16xi32>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : memref<16xf32>
%3 = gpu.block_id x
%4 = gpu.block_dim x
%5 = gpu.thread_id x
%6 = arith.muli %3, %4 : index
%7 = arith.addi %6, %5 : index
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 4>>
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
%10 = memref.load %1[%7] : memref<16xi32>
%11 = arith.sitofp %10 : i32 to f32
%12 = arith.addf %9, %11 : f32
Expand All @@ -167,8 +167,13 @@ hal.executable @mixed_type {
// CHECK-LABEL: llvm.func @mixed_type
// CHECK-SAME: (%[[ARG0:.+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %{{.*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef})
// CHECK: %[[BYTES_PER_BIT:.+]] = llvm.mlir.constant(8 : i64) : i64
// CHECK: %[[BITS_PER_ELEM:.+]] = llvm.mlir.constant(32 : i64) : i64
// CHECK: %[[BYTE_OFFSET:.+]] = llvm.mlir.constant(128 : index) : i64
// CHECK: %[[OFFSET_BITS:.+]] = llvm.mul %[[BYTE_OFFSET]], %[[BYTES_PER_BIT]]
// CHECK: %[[OFFSET_ELEMS:.+]] = llvm.udiv %[[OFFSET_BITS]], %[[BITS_PER_ELEM]]
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.getelementptr %[[ARG0]][4] : (!llvm.ptr) -> !llvm.ptr, f32
// CHECK: llvm.getelementptr %[[ARG0]][%[[OFFSET_ELEMS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.fadd

// -----
Expand Down Expand Up @@ -282,18 +287,18 @@ hal.executable @check_not_readonly {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<16xi32>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
%b11 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) flags(ReadOnly) : memref<16xi32>
%b12 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) : memref<16xf32, strided<[1], offset: 32>>
%b12 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%c128) : memref<16xf32, strided<[1], offset: ?>>
%b21 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) flags(ReadOnly) : memref<16xi32>
%b22 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: 32>>
%b22 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) offset(%c128) flags(ReadOnly) : memref<16xf32, strided<[1], offset: ?>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) : memref<16xf32>
%3 = gpu.block_id x
%4 = gpu.block_dim x
%5 = gpu.thread_id x
%6 = arith.muli %3, %4 : index
%7 = arith.addi %6, %5 : index
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: 32>>
%9 = memref.load %0[%7] : memref<16xf32, strided<[1], offset: ?>>
%10 = memref.load %1[%7] : memref<16xi32>
%11 = arith.sitofp %10 : i32 to f32
%12 = arith.addf %9, %11 : f32
Expand Down

0 comments on commit 4bc495b

Please sign in to comment.