Skip to content

Commit

Permalink
[mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended
Browse files Browse the repository at this point in the history
Add missing constant propagation folder for IAddCarry and
[S|U]MulExtended. Due to currently missing constant value for
spirv.struct the folding is done using canonicalization patterns.

Implement additional folding when rhs is 0 for all ops and when rhs
is 1 for UMulExt.

This helps for readability of lowered code into SPIRV.

Part of work for llvm#70704
  • Loading branch information
inbelic committed Nov 24, 2023
1 parent cc21287 commit dcbfc96
Show file tree
Hide file tree
Showing 3 changed files with 344 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down Expand Up @@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down Expand Up @@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down
190 changes: 190 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// Canonicalization pattern for spirv.IAddCarry.
//
// The op yields a two-member struct. Struct values are not yet representable
// as a constant, so the result cannot be produced by a fold; instead the op is
// rewritten into a spirv.CompositeConstruct of its two members.
//
// According to the SPIR-V spec:
//
//  Result Type must be from OpTypeStruct.  The struct must have two
//  members...
//
//  Member 0 of the result gets the low-order bits (full component width) of
//  the addition.
//
//  Member 1 of the result gets the high-order (carry) bit of the result of
//  the addition. That is, it gets the value 1 if the addition overflowed
//  the component width, and 0 otherwise.
//
// Two cases are handled:
//  * rhs is a constant zero:   iaddcarry (x, 0) = <x, 0>
//  * both operands constant:   <lhs + rhs, carry(lhs + rhs)>
// Any other input is left untouched (the pattern returns failure()).
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto operands = op.getOperands();

    SmallVector<Value> constituents;
    Type constituentType = operands[0].getType();

    // iaddcarry (x, 0) = <x, 0>
    //
    // Adding zero can never overflow, so the sum (member 0) is x itself and
    // the carry (member 1) is zero. The matched zero operand already has the
    // constituent type, so it is reused as the carry value.
    if (matchPattern(operands[1], m_Zero())) {
      constituents.push_back(operands[0]);
      constituents.push_back(operands[1]);
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    // From here on both operands must be constants.
    Attribute lhs;
    Attribute rhs;
    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
        !matchPattern(operands[1], m_Constant(&rhs)))
      return failure();

    // Member 0: the wrapped (modulo 2^width) sum.
    auto adds = constFoldBinaryOp<IntegerAttr>(
        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
    if (!adds)
      return failure();

    // Member 1: the carry. An unsigned addition overflowed exactly when the
    // wrapped sum is (unsigned) smaller than one of its inputs, so compare the
    // sum against lhs.
    auto carrys = constFoldBinaryOp<IntegerAttr>(
        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
          APInt zero = APInt::getZero(a.getBitWidth());
          return a.ult(b) ? (zero + 1) : zero;
        });

    if (!carrys)
      return failure();

    // Only materialize new ops once both members are known to be foldable.
    Value addsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
    constituents.push_back(addsVal);

    Value carrysVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
    constituents.push_back(carrysVal);

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                             constituents);
    return success();
  }
};

// Registers the canonicalization patterns of spirv.IAddCarry into `patterns`.
// At present this is only IAddCarryFold, constructed with `context`.
void spirv::IAddCarryOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<IAddCarryFold>(context);
}

//===----------------------------------------------------------------------===//
// spirv.[S|U]MulExtended
//===----------------------------------------------------------------------===//

// Shared canonicalization pattern for spirv.SMulExtended / spirv.UMulExtended.
//
// `MulOp` is the op being matched and `IsSigned` selects whether the operands
// are sign- or zero-extended when the high half of the product is computed.
//
// Struct values cannot yet be expressed as a constant, so the folded result is
// emitted as a spirv.CompositeConstruct from a rewrite pattern rather than
// being returned from a fold.
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
  using OpRewritePattern<MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto inputs = op.getOperands();
    Type memberType = inputs[0].getType();

    // [su]mulextended (x, 0) = <0, 0>
    if (matchPattern(inputs[1], m_Zero())) {
      Value zero = spirv::ConstantOp::getZero(memberType, loc, rewriter);
      SmallVector<Value> members{zero, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               members);
      return success();
    }

    // According to the SPIR-V spec:
    //
    //  Result Type must be from OpTypeStruct.  The struct must have two
    //  members...
    //
    //  Member 0 of the result gets the low-order bits of the multiplication.
    //
    //  Member 1 of the result gets the high-order bits of the multiplication.
    //
    // The remaining case therefore needs both operands to be constants.
    Attribute lhsAttr;
    Attribute rhsAttr;
    const bool bothConstant = matchPattern(inputs[0], m_Constant(&lhsAttr)) &&
                              matchPattern(inputs[1], m_Constant(&rhsAttr));
    if (!bothConstant)
      return failure();

    // Member 0: the product truncated to the operand width.
    Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a * b; });
    if (!lowAttr)
      return failure();

    // Member 1: multiply at twice the operand width, then keep the upper half.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
          const unsigned width = a.getBitWidth();
          const unsigned wideWidth = width * 2;
          APInt wideProduct = IsSigned
                                  ? a.sext(wideWidth) * b.sext(wideWidth)
                                  : a.zext(wideWidth) * b.zext(wideWidth);
          return wideProduct.extractBits(width, width);
        });
    if (!highAttr)
      return failure();

    // Both halves folded: materialize them (low first, then high) and build
    // the result struct.
    Value lowValue =
        rewriter.create<spirv::ConstantOp>(loc, memberType, lowAttr);
    Value highValue =
        rewriter.create<spirv::ConstantOp>(loc, memberType, highAttr);
    SmallVector<Value> members{lowValue, highValue};
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                             members);
    return success();
  }
};

// Signed instantiation of the shared multiply-extended pattern
// (IsSigned = true).
using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
// Registers the canonicalization patterns of spirv.SMulExtended into
// `patterns`; currently only SMulExtendedOpFold.
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<SMulExtendedOpFold>(context);
}

// Canonicalization pattern for spirv.UMulExtended with a constant-one
// right-hand side:
//
//   umulextended (x, 1) = <x, 0>
//
// The result is built with spirv.CompositeConstruct from x and a freshly
// created zero constant of the operand type.
struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                PatternRewriter &rewriter) const override {
    auto inputs = op.getOperands();

    // Nothing to do unless the multiplier is the constant one.
    if (!matchPattern(inputs[1], m_One()))
      return failure();

    Location loc = op.getLoc();
    Type memberType = inputs[0].getType();

    // Low half is x unchanged; high half is zero.
    Value highZero = spirv::ConstantOp::getZero(memberType, loc, rewriter);
    SmallVector<Value> members{inputs[0], highZero};
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                             members);
    return success();
  }
};

// Unsigned instantiation of the shared multiply-extended pattern
// (IsSigned = false).
using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
// Registers the canonicalization patterns of spirv.UMulExtended into
// `patterns`: the shared fold pattern plus the x * 1 pattern.
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
Expand Down
148 changes: 148 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {

// -----

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// spirv.IAddCarry with a constant-zero right-hand side must be rewritten into
// a spirv.CompositeConstruct. Only the presence of the composite is verified
// here; the order of its operands is not.
// CHECK-LABEL: @iaddcarry_x_0
func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
%c0 = spirv.Constant 0 : i32

// CHECK: spirv.CompositeConstruct
%0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
return %0 : !spirv.struct<(i32, i32)>
}

// Constant folding of scalar spirv.IAddCarry.
//   5 + (-8)    = -3,  no unsigned overflow -> carry 0
//   (-5) + (-8) = -13, unsigned overflow    -> carry 1
// CHECK-LABEL: @const_fold_scalar_iaddcarry
func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
%c5 = spirv.Constant 5 : i32
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32

// The DAG checks below accept the folded constants in any order.
// CHECK-DAG: spirv.Constant 0
// CHECK-DAG: spirv.Constant -3
// CHECK-DAG: spirv.CompositeConstruct
// CHECK-DAG: spirv.Constant 1
// CHECK-DAG: spirv.Constant -13
// CHECK-DAG: spirv.CompositeConstruct
%0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>

return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}

// Constant folding of vector spirv.IAddCarry, applied element-wise.
//   sums:    [5 + -8, -3 + -8, -1 + 1] = [-3, -11, 0]
//   carries: [0, 1, 1]
// CHECK-LABEL: @const_fold_vector_iaddcarry
func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
%v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>

// CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
// CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
// CHECK-DAG: spirv.CompositeConstruct
%0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>

}

// -----

//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {

// -----

//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//

// spirv.SMulExtended with a constant-zero right-hand side must be rewritten
// into a spirv.CompositeConstruct. Only the presence of the composite is
// verified here.
// CHECK-LABEL: @smulextended_x_0
func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
%c0 = spirv.Constant 0 : i32

// CHECK: spirv.CompositeConstruct
%0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
return %0 : !spirv.struct<(i32, i32)>
}

// Constant folding of scalar spirv.SMulExtended (signed 64-bit product split
// into two i32 halves).
//   5 * (-8)    = -40 -> low -40, high -1 (sign bits)
//   (-5) * (-8) =  40 -> low  40, high  0
// CHECK-LABEL: @const_fold_scalar_smulextended
func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
%c5 = spirv.Constant 5 : i32
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32

// The DAG checks below accept the folded constants in any order.
// CHECK-DAG: spirv.Constant -40
// CHECK-DAG: spirv.Constant -1
// CHECK-DAG: spirv.CompositeConstruct
// CHECK-DAG: spirv.Constant 40
// CHECK-DAG: spirv.Constant 0
// CHECK-DAG: spirv.CompositeConstruct
%0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>

return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}

// Constant folding of vector spirv.SMulExtended, applied element-wise.
//   2147483647 * 5 = 0x2_7FFFFFFB -> low 2147483643, high  2
//   (-5) * (-8)    = 40           -> low 40,         high  0
//   (-1) * 1       = -1           -> low -1,         high -1
// CHECK-LABEL: @const_fold_vector_smulextended
func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>

// The low-bits constant is expected first, immediately followed by the
// high-bits constant and then the composite.
// CHECK: spirv.Constant dense<[2147483643, 40, -1]>
// CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
// CHECK-NEXT: spirv.CompositeConstruct
%0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>

}

// -----

//===----------------------------------------------------------------------===//
// spirv.UMulExtended
//===----------------------------------------------------------------------===//

// spirv.UMulExtended with a constant-zero right-hand side must be rewritten
// into a spirv.CompositeConstruct. Only the presence of the composite is
// verified here.
// CHECK-LABEL: @umulextended_x_0
func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
%c0 = spirv.Constant 0 : i32

// CHECK: spirv.CompositeConstruct
%0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
return %0 : !spirv.struct<(i32, i32)>
}

// spirv.UMulExtended with a constant-one right-hand side must be rewritten
// into a spirv.CompositeConstruct. Only the presence of the composite is
// verified here.
// CHECK-LABEL: @umulextended_x_1
func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
  %c1 = spirv.Constant 1 : i32

  // CHECK: spirv.CompositeConstruct
  %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
  return %0 : !spirv.struct<(i32, i32)>
}

// Constant folding of scalar spirv.UMulExtended (operands are treated as
// unsigned 32-bit values; results are printed as signed i32).
//   5 * 0xFFFFFFF8          -> low -40, high 4
//   0xFFFFFFFB * 0xFFFFFFF8 -> low  40, high -13 (0xFFFFFFF3)
// CHECK-LABEL: @const_fold_scalar_umulextended
func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
%c5 = spirv.Constant 5 : i32
%cn5 = spirv.Constant -5 : i32
%cn8 = spirv.Constant -8 : i32

// The DAG checks below accept the folded constants in any order.
// CHECK-DAG: spirv.Constant 40
// CHECK-DAG: spirv.Constant -13
// CHECK-DAG: spirv.CompositeConstruct
// CHECK-DAG: spirv.Constant -40
// CHECK-DAG: spirv.Constant 4
// CHECK-DAG: spirv.CompositeConstruct
%0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
%1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>

return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}

// Constant folding of vector spirv.UMulExtended, applied element-wise on
// unsigned values (results are printed as signed i32).
//   2147483647 * 5          -> low 2147483643, high 2
//   0xFFFFFFFB * 0xFFFFFFF8 -> low 40,         high -13 (0xFFFFFFF3)
//   0xFFFFFFFF * 1          -> low -1,         high 0
// CHECK-LABEL: @const_fold_vector_umulextended
func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>

// The low-bits constant is expected first, immediately followed by the
// high-bits constant and then the composite.
// CHECK: spirv.Constant dense<[2147483643, 40, -1]>
// CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
// CHECK-NEXT: spirv.CompositeConstruct
%0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>

}

// -----


//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit dcbfc96

Please sign in to comment.