-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add basic arithmetic folds #71414
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
d5f5891
to
7129768
Compare
2c01c14
to
db168bb
Compare
We have missing basic constant folds for SPIR-V arithmetic operations which negatively impacts readability of lowered or otherwise generated code. This commit works to implementing them to improve the mentioned hinderences. Also corrects some folds that were found to be incorrect during testing. Resolves llvm#70704
We have missing basic constant folds for SPIR-V bit operations which negatively impacts readability of lowered or otherwise generated code. This commit works to implementing them to improve the mentioned hinderences. Resolves llvm#70704
We have missing basic constant folds for SPIR-V logical operations which negatively impacts readability of lowered or otherwise generated code. This commit works to implementing them to improve the mentioned hinderences. Corrects some testcases in logical-ops-to-llvm as required. Resolves llvm#70704
db168bb
to
f252ccf
Compare
Not sure on the commit granularity, but happy to rebase if we should have it more/less commits. Additionally, I plan to add the various the integer comparison ops for logical as well. But I will ask for review now to make sure it looks good so far. |
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Finn Plummer (inbelic) ChangesWe have missing basic constant folds for SPIR-V arithmetic operations which negatively impacts readability of lowered or otherwise generated code. This commit works to implementing them to improve the mentioned hindrances. Resolves #70704 Patch is 69.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71414.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a73989c41c04cfb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -534,6 +536,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +577,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -607,6 +613,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -634,6 +642,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
%3 = spirv.SNegate %2 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -673,6 +683,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -707,6 +719,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -742,6 +756,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -811,6 +827,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
```
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f621..dbba4f7ec6cff76 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -412,6 +412,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
%2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -457,6 +459,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
%5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -499,6 +503,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
%5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -542,6 +548,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
%5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +581,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
%3 = spirv.Not %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
#endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc326..0053cd5fc9448b5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
- let hasFolder = true;
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..ba2281d30bdb589 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+ bool div0 = b.isZero();
+
+ bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+ return div0 || overflow;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
@@ -115,6 +123,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // iaddcarry (x, 0) = <0, x>
+ if (matchPattern(operands[1], m_Zero())) {
+ constituents.push_back(operands[1]);
+ constituents.push_back(operands[0]);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits (full component width) of
+ // the addition.
+ //
+ // Member 1 of the result gets the high-order (carry) bit of the result of
+ // the addition. That is, it gets the value 1 if the addition overflowed
+ // the component width, and 0 otherwise.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto adds = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+ if (!adds)
+ return failure();
+
+ auto carrys = constFoldBinaryOp<IntegerAttr>(
+ ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return a.ult(b) ? (zero + 1) : zero;
+ });
+
+ if (!carrys)
+ return failure();
+
+ Value addsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+ constituents.push_back(addsVal);
+
+ Value carrysVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+ constituents.push_back(carrysVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+ using OpRewritePattern<MulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // [su]mulextended (x, 0) = <0, 0>
+ if (matchPattern(operands[1], m_Zero())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(zero);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits of the multiplication.
+ //
+ // Member 1 of the result gets the high-order bits of the multiplication.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+ if (!lowBits)
+ return failure();
+
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt c;
+ if (IsSigned) {
+ c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ } else {
+ c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ }
+ return c.extractBits(bitWidth, bitWidth); // Extract high result
+ });
+
+ if (!highBits)
+ return failure();
+
+ Value lowBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ constituents.push_back(lowBitsVal);
+
+ Value highBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ constituents.push_back(highBitsVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // umulextended (x, 1) = <x, 0>
+ if (matchPattern(operands[1], m_One())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(operands[0]);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
@@ -278,7 +476,7 @@ OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
// x - x = 0
if (getOperand1() == getOperand2())
- return Builder(getContext()).getIntegerAttr(getType(), 0);
+ return Builder(getContext()).getZeroAttr(getType());
// According to the SPIR-V spec:
//
@@ -290,6 +488,178 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
[](APInt a, const APInt &b) { return std::move(a) - b; });
}
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+ // sdiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer division of Operand 1 divided by Operand 2.
+ // Results are computed per component. Behavior is undefined if Operand 2 is
+ // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+ // representable value for the operands' type, causing signed overflow.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.sdiv(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+ // smod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+ //
+ // So don't fold during undefined behaviour
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ APInt c = a.abs().urem(b.abs());
+ if (c.isZero())
+ return c;
+ if (b.isNegative()) {
+ APInt zero = APInt::getZero(c.getBitWidth());
+ return a.isNegative() ? (zero - c) : (b + c);
+ }
+ return a.isNegative() ? (b - c) : c;
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+ // -(-x) = 0 - (0 - x) = x
+ auto op = getOperand();
+ if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+ return negateOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer subtract of Operand from zero.
+ return constFoldUnaryOp<IntegerAttr>(
+ adaptor.getOperands(), [](const APInt &a) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return zero - a;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+ // x % 1 = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+ // Don't fold if it would do undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.srem(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+ // udiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+ // umod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.urem(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -309,6 +679,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
//===-----------------------------------...
[truncated]
|
Wow, thanks for contributing this @inbelic!
Yes, splitting into multiple smaller PRs will make this much easier to review. I'd keep them at the level of granularity of up to a few related ops (for example logical ops like and/or can go together, signed/unsigned variants of the same operation, etc.).
SG! |
@kuhar :) Okay, makes sense thanks. Will close this pr then and create the others hopefully later this week. |
We have missing basic constant folds for SPIR-V arithmetic operations which negatively impacts readability of lowered or otherwise generated code. This commit works to implementing them to improve the mentioned hindrances.
Resolves #70704