From 14028ec0a62210d68a4dd7a046ac79c8c3b7727e Mon Sep 17 00:00:00 2001 From: Finn Plummer <50529406+inbelic@users.noreply.github.com> Date: Wed, 29 Nov 2023 20:32:13 +0100 Subject: [PATCH] [mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended (#73340) Add missing constant propagation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns. Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt. This helps for readability of lowered code into SPIR-V. Part of work for #70704 --- .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 6 + .../SPIRV/IR/SPIRVCanonicalization.cpp | 194 ++++++++++++++++++ .../SPIRV/Transforms/canonicalize.mlir | 182 ++++++++++++++++ 3 files changed, 382 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 701389d1cf4c1e..51124e141c6d46 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -316,6 +316,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry", %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> ``` }]; + + let hasCanonicalizer = 1; } // ----- @@ -551,6 +553,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended", %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> ``` }]; + + let hasCanonicalizer = 1; } // ----- @@ -675,6 +679,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended", %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> ``` }]; + + let hasCanonicalizer = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 82af41643edb89..22cb9bf718e36f 100644 --- 
a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -122,6 +122,200 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Folds spirv.IAddCarry as a pattern rather than a fold(): a constant struct
+// cannot be materialized yet, so the result is built with composite ops.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
+
+    // iaddcarry (x, 0) = <x, 0>: member 0 is the sum, member 1 the carry.
+    if (matchPattern(rhs, m_Zero())) {
+      Value constituents[2] = {lhs, rhs};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhsAttr;
+    Attribute rhsAttr;
+    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+        !matchPattern(rhs, m_Constant(&rhsAttr)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhsAttr, rhsAttr},
+        [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero; // sum < lhs iff it wrapped
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+
+    // Create empty struct
+    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    // Fill in adds at id 0
+    Value intermediate =
+        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+    // Fill in carrys at id 1
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
+                                                          intermediate, 1);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// Folds spirv.[S|U]MulExtended as a pattern rather than a fold(): a constant
+// struct cannot be materialized yet, so the result uses composite ops.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(rhs, m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      Value constituents[2] = {zero, zero};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhsAttr;
+    Attribute rhsAttr;
+    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+        !matchPattern(rhs, m_Constant(&rhsAttr)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhsAttr, rhsAttr},
+        [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+
+    // Create empty struct
+    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    // Fill in lowBits at id 0
+    Value intermediate =
+        rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+    // Fill in highBits at id 1
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
+                                                          intermediate, 1);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(rhs, m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      Value constituents[2] = {lhs, zero};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6fb5ca5c41839a..867ddf3c801733 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,61 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: %[[RET:.*]] = spirv.CompositeConstruct
+  %c0 = spirv.Constant 0 : i32
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+
+  // CHECK: return %[[RET]]
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
+  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
+  // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
+  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
+  // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
+  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], 
%[[UNDEF2]][0 : i32] + // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32] + %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)> + %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)> + + // CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]] + return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @const_fold_vector_iaddcarry +func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> { + %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32> + %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32> + + // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]> + // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]> + // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32] + // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32] + %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + + // CHECK: return %[[CC_CV1_CV2]] + return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.IMul //===----------------------------------------------------------------------===// @@ -400,6 +455,133 @@ func.func @const_fold_vector_imul() -> vector<3xi32> { // ----- +//===----------------------------------------------------------------------===// +// spirv.SMulExtended +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @smulextended_x_0 +func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> { + // CHECK: %[[C0:.*]] = spirv.Constant 0 + // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]] + %c0 = spirv.Constant 0 : i32 + %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)> + + // CHECK: return %[[RET]] + return %0 : !spirv.struct<(i32, i32)> +} + 
+// CHECK-LABEL: @const_fold_scalar_smulextended +func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) { + %c5 = spirv.Constant 5 : i32 + %cn5 = spirv.Constant -5 : i32 + %cn8 = spirv.Constant -8 : i32 + + // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40 + // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1 + // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32] + // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]] + // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40 + // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0 + // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32] + // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32] + %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)> + %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)> + + // CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]] + return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @const_fold_vector_smulextended +func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> { + %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32> + %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32> + + // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]> + // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]> + // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32] + // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32] + %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + + // CHECK: return %[[CC_CV1_CV2]] + return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + +} + +// ----- + 
+//===----------------------------------------------------------------------===// +// spirv.UMulExtended +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @umulextended_x_0 +func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> { + // CHECK: %[[C0:.*]] = spirv.Constant 0 + // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]] + %c0 = spirv.Constant 0 : i32 + %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)> + + // CHECK: return %[[RET]] + return %0 : !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @umulextended_x_1 +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> { + // CHECK: %[[C0:.*]] = spirv.Constant 0 + // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[ARG]], %[[C0]] + %c0 = spirv.Constant 1 : i32 + %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)> + + // CHECK: return %[[RET]] + return %0 : !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @const_fold_scalar_umulextended +func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) { + %c5 = spirv.Constant 5 : i32 + %cn5 = spirv.Constant -5 : i32 + %cn8 = spirv.Constant -8 : i32 + + + // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40 + // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13 + // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40 + // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4 + // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32] + // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32] + // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32] + // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32] + %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)> + %1 = spirv.UMulExtended %cn5, %cn8 : 
!spirv.struct<(i32, i32)> + + // CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]] + return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @const_fold_vector_umulextended +func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> { + %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32> + %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32> + + // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]> + // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]> + // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef + // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]] + // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]] + %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + + // CHECK: return %[[CC_CV1_CV2]] + return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> +} + +// ----- + + //===----------------------------------------------------------------------===// // spirv.ISub //===----------------------------------------------------------------------===//