[mlir][spirv] Improve folding of MemRef to SPIRV Lowering (llvm#85433)
Investigate the lowering of MemRef Load/Store ops and implement additional
folding of the created ops.

Aims to improve the readability of the generated SPIR-V code.

Part of work on llvm#70704
inbelic authored and chencha3 committed Mar 22, 2024
1 parent dc0f0b9 commit e4b57ab
Showing 8 changed files with 93 additions and 210 deletions.
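
For context on the API being swapped in: OpBuilder::createOrFold<OpTy> builds the op and immediately runs its folder, returning the simplified Value when the folder succeeds instead of leaving a trivial op (such as a multiply by a constant 1 or an add of 0) in the IR. The folders for the SPIR-V arithmetic ops are being added as part of the llvm#70704 work referenced above. A minimal sketch of the effect, assuming the MLIR C++ API; the helper name scaleIndex is hypothetical and not part of the patch:

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: scale an index Value by a compile-time stride.
static Value scaleIndex(Location loc, Value index, int64_t stride,
                        OpBuilder &builder) {
  Type type = index.getType();
  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
      loc, type, builder.getIntegerAttr(type, stride));
  // builder.create<spirv::IMulOp> would always materialize the constant and
  // the multiply; createOrFold returns `index` directly when stride == 1.
  return builder.createOrFold<spirv::IMulOp>(loc, index, strideVal);
}

This is why the mechanical create -> createOrFold swap in the hunks below is enough to drop the redundant constants, multiplies, and adds from the generated SPIR-V, as the updated load-store.mlir test at the end shows.
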
52 changes: 28 additions & 24 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
   assert(targetBits % sourceBits == 0);
   Type type = srcIdx.getType();
   IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
   IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
-  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
-  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
-  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+  auto srcBitsValue =
+      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
 }
 
 /// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
   Value lastDim = op->getOperand(op.getNumOperands() - 1);
   Type type = lastDim.getType();
   IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
   auto indices = llvm::to_vector<4>(op.getIndices());
   // There are two elements if this is a 1-D tensor.
   assert(indices.size() == 2);
-  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
   Type t = typeConverter.convertType(op.getComponentPtr().getType());
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
     return srcBool;
   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+                                               zero);
 }
 
 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
           loc, builder.getIntegerType(targetBits), value);
     }
 
-    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
   }
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
-                                                   offset);
+  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+                                                          value, offset);
 }
 
 /// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
     return srcInt;
 
   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
-  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
 }
 
 //===----------------------------------------------------------------------===//
@@ -597,25 +599,26 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   //     ____XXXX________ -> ____________XXXX
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
-  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
       loc, spvLoadOp.getType(), spvLoadOp, offset);
 
   // Apply the mask to extract corresponding bits.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  result =
+      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
 
   // Apply sign extension on the loading value unconditionally. The signedness
   // semantic is carried in the operator itself, we relies other pattern to
   // handle the casting.
   IntegerAttr shiftValueAttr =
       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
   Value shiftValue =
-      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
-  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
-                                                      shiftValue);
-  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
-                                                           shiftValue);
+      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+                                                            result, shiftValue);
+  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+      loc, dstType, result, shiftValue);
 
   rewriter.replaceOp(loadOp, result);
 
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   // Create a mask to clear the destination. E.g., if it is the second i8 in
   // i32, 0xFFFF00FF is created.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  Value clearBitsMask =
-      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
-  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+      loc, dstType, mask, offset);
+  clearBitsMask =
+      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
 
   Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 
     int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
     Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
-    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
   }();
 
   rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
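An aside on the bit manipulation in IntLoadOpPattern and IntStoreOpPattern above (the arithmetic itself is unchanged by this commit; the ops are now simply created through createOrFold): the following standalone check walks through the same steps with plain integers. The concrete word value and byte offset are assumed for illustration; the 0xFFFF00FF clear mask is the value quoted in the IntStoreOpPattern comment.

#include <cassert>
#include <cstdint>

int main() {
  const int srcBits = 8, dstBits = 32;

  // IntStoreOpPattern: mask and clear-mask for the second i8 inside an i32
  // (bit offset 8); ~(0xFF << 8) gives the 0xFFFF00FF from the comment.
  const uint32_t mask = (1u << srcBits) - 1;      // 0x000000FF
  const uint32_t clearBitsMask = ~(mask << 8);    // 0xFFFF00FF
  assert(clearBitsMask == 0xFFFF00FFu);

  // IntLoadOpPattern: extract that byte from a loaded word and sign-extend it,
  // mirroring the ShiftRightArithmetic / BitwiseAnd / shift-left / shift-right
  // sequence in the pattern (relies on two's-complement shift behavior).
  const uint32_t word = 0x0000AB00;               // assumed stored i32 word
  const uint32_t extracted = (word >> 8) & mask;  // 0x000000AB
  const int32_t result = (int32_t)(extracted << (dstBits - srcBits)) >>
                         (dstBits - srcBits);
  assert(result == -85);                          // i8 0xAB sign-extended
  return 0;
}
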
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   // broken down into progressive small steps so we can have intermediate steps
   // using other dialects. At the moment SPIR-V is the final sink.
 
-  Value linearizedIndex = builder.create<spirv::ConstantOp>(
+  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
       loc, integerType, IntegerAttr::get(integerType, offset));
   for (const auto &index : llvm::enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
+    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
         loc, integerType,
         IntegerAttr::get(integerType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    Value update =
+        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
     linearizedIndex =
-        builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
   }
   return linearizedIndex;
 }
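One more note on the linearizeIndex change above: besides switching to createOrFold, the operands of the IMul and IAdd are commuted so the constant operand (the stride, and the running value that starts as the constant offset) ends up on the right-hand side, presumably where the folders look for it. Below is a plain-integer model of the computation, using the strides {4, 1} and offset 0 of the memref<12x4xf32> in the load-store.mlir test that follows; the index values are made up for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of spirv::linearizeIndex: offset + sum(index[i] * stride[i]).
static int64_t linearize(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &strides, int64_t offset) {
  int64_t linear = offset;
  for (size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i];
  return linear;
}

int main() {
  assert(linearize({3, 2}, {4, 1}, 0) == 14); // row 3, column 2 of a 12x4 buffer
  return 0;
}

With the folders in place, the SPIR-V emitted for indices (%i, %j) collapses to a single spirv.IMul %i, %cst4 followed by spirv.IAdd %j, %update, since the add of the constant 0 offset and the multiply by the unit stride disappear; the updated CHECK lines in the test below reflect exactly this.
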
8 changes: 2 additions & 6 deletions mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
       // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
       %13 = arith.addi %arg4, %3 : index
       // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
-      // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
       // CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
-      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
-      // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
-      // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
-      // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
-      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
       // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
       %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
(The remaining 5 changed files are not shown in this view.)
