From a85ce4090ec3c5c1a1b36c56c0726230f8b4596e Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Wed, 20 Nov 2024 20:21:01 +0100 Subject: [PATCH 1/7] [MLIR][Arith] Add denormal attribute to binary/unary operations Add support for denormal in the Arith dialect (binary and unary operations). Denormal are attached to every operation, and they can be of three different kinds: 1) ieee, denormal are preserved and processed as defined by IEEE 754 rules. 2) preserve sign, a mode where denormal numbers are flushed to zero, but the sign of the zero (+0 or -0) is preserved. 3) positive zero, a mode where all denormal numbers are flushed to positive zero (+0), ignoring the sign of the original number. Denormal refers to both the operands and the result. --- .../ArithCommon/AttrToLLVMConverter.h | 6 +- .../mlir/Dialect/Arith/IR/ArithBase.td | 34 +++++++++ .../include/mlir/Dialect/Arith/IR/ArithOps.td | 22 +++--- .../Dialect/Arith/IR/ArithOpsInterfaces.td | 40 ++++++++++- mlir/include/mlir/IR/Matchers.h | 6 ++ .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 67 ++++++++++++------ .../Dialect/Arith/IR/ArithCanonicalization.td | 14 ++-- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 24 ++++++- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 +++-- mlir/test/CAPI/ir.c | 2 +- mlir/test/Dialect/Arith/canonicalize.mlir | 23 ++++++ mlir/test/Dialect/Arith/invalid.mlir | 8 +++ mlir/test/Dialect/Arith/ops.mlir | 70 +++++++++++++++++++ mlir/test/Dialect/Linalg/invalid.mlir | 2 +- 14 files changed, 286 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 7ffc8613317603..da067410db5eff 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context); template class AttrConvertFastMathToLLVM { public: - AttrConvertFastMathToLLVM(SourceOp srcOp) { + explicit AttrConvertFastMathToLLVM(SourceOp srcOp) { // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith fastmath attribute. @@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM { template class AttrConvertOverflowToLLVM { public: - AttrConvertOverflowToLLVM(SourceOp srcOp) { + explicit AttrConvertOverflowToLLVM(SourceOp srcOp) { // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith overflow attribute. @@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM { "LLVM::FPExceptionBehaviorOpInterface"); public: - AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { + explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index 19a2ade2e95a0e..4309c0618667a8 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -181,4 +181,38 @@ def Arith_RoundingModeAttr : I32EnumAttr< let cppNamespace = "::mlir::arith"; } +//===----------------------------------------------------------------------===// +// Arith_DenormalMode +//===----------------------------------------------------------------------===// + +// Denormal mode is applied on operands and results. For example, if denormal = +// preserve_sign, operands and results will be flushed to sign preserving zero. +// We do not distinguish between operands and results. + +// The default mode. Denormals are preserved and processed as defined +// by IEEE 754 rules. +def Arith_DenormalModeIEEE : I32BitEnumAttrCaseNone<"ieee">; + +// A mode where denormal numbers are flushed to zero, but the sign of the zero +// (+0 or -0) is preserved. +def Arith_DenormalModePreserveSign : I32BitEnumAttrCase<"preserve_sign", 1>; + +// A mode where all denormal numbers are flushed to positive zero (+0), +// ignoring the sign of the original number. +def Arith_DenormalModePositiveZero : I32BitEnumAttrCase<"positive_zero", 2>; + +def Arith_DenormalMode : I32BitEnumAttr< + "DenormalMode", "denormal mode arith", + [Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign, + Arith_DenormalModePositiveZero]> { + let cppNamespace = "::mlir::arith"; + let genSpecializedAttr = 0; +} + +def Arith_DenormalModeAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; + let genVerifyDecl = 1; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 19a5e13a5d755d..4069e43af82e8e 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -61,26 +61,35 @@ class Arith_TotalIntBinaryOp traits = []> : // Base class for floating point unary operations. class Arith_FloatUnaryOp traits = []> : Arith_UnaryOp], + !listconcat([DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods], traits)>, Arguments<(ins FloatLike:$operand, DefaultValuedAttr< - Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>, + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath, + DefaultValuedAttr< + Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>, Results<(outs FloatLike:$result)> { let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? + (`denormal` `` $denormal^)? attr-dict `:` type($result) }]; } // Base class for floating point binary operations. class Arith_FloatBinaryOp traits = []> : Arith_BinaryOp], + !listconcat([Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods], traits)>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, DefaultValuedAttr< - Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>, + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath, + DefaultValuedAttr< + Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>, Results<(outs FloatLike:$result)> { - let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)? + let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)? + (`denormal` `` $denormal^)? attr-dict `:` type($result) }]; } @@ -1085,7 +1094,6 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> { let hasFolder = 1; } - //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -1111,8 +1119,6 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> { %x = arith.mulf %y, %z : tensor<4x?xbf16> ``` - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. }]; let hasFolder = 1; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td index 82d6c9ad6b03da..270d80f2ec73af 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -45,13 +45,12 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> { return "fastmath"; }] > - ]; } def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> { let description = [{ - Access to op integer overflow flags. + Access to operation integer overflow flags. }]; let cppNamespace = "::mlir::arith"; @@ -108,7 +107,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> { let description = [{ - Access to op rounding mode. + Access to operation rounding mode. }]; let cppNamespace = "::mlir::arith"; @@ -139,4 +138,39 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> { ]; } + +def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> { + let description = [{ + Access the operation denormal modes. + }]; + + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a DenormalModeAttr attribute for the operation", + /*returnType=*/ "DenormalModeAttr", + /*methodName=*/ "getDenormalModeAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getDenormalAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the DenormalModeAttr attribute for + the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getDenormalModeAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "denormal"; + }] + > + ]; +} + + #endif // ARITH_OPS_INTERFACES diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 6fa5a47109d20d..226afb9ad25f1a 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -433,6 +433,12 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() { }}; } +/// Matches a constant scalar / vector splat / tensor splat with denormal +/// values. +inline detail::constant_float_predicate_matcher m_isDenormalFloat() { + return {[](const APFloat &value) { return value.isDenormal(); }}; +} + /// Matches a constant scalar / vector splat / tensor splat integer zero. inline detail::constant_int_predicate_matcher m_Zero() { return {[](const APInt &value) { return 0 == value; }}; diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index aac24f113d891f..54d941ae9f6c89 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -53,13 +53,40 @@ struct ConstrainedVectorConvertToLLVMPattern } }; +template typename AttrConvert = + AttrConvertPassThrough> +struct DenormalOpConversionToLLVMPattern + : public VectorConvertToLLVMPattern { + using VectorConvertToLLVMPattern::VectorConvertToLLVMPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Here, we need a legalization step. LLVM provides a function-level + // attribute for denormal; here, we need to move this information from the + // operation to the function, making sure all the operations in the same + // function are consistent. + if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee) + return rewriter.notifyMatchFailure( + op, "only ieee denormal mode is supported at the moment"); + + StringRef arithDenormalAttrName = SourceOp::getDenormalModeAttrName(); + op->removeAttr(arithDenormalAttrName); + return VectorConvertToLLVMPattern::matchAndRewrite(op, adaptor, + rewriter); + } +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// using AddFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; @@ -67,8 +94,8 @@ using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = @@ -83,38 +110,38 @@ using FPToSIOpLowering = using FPToUIOpLowering = VectorConvertToLLVMPattern; using MaximumFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using MaxNumFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; using MinimumFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using MinNumFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = @@ -131,8 +158,8 @@ using ShRUIOpLowering = using SIToFPOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = - VectorConvertToLLVMPattern; + DenormalOpConversionToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 6d7ac2be951dd7..22c34b2bd42f58 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -422,10 +422,11 @@ def TruncIShrUIMulIToMulUIExtended : //===----------------------------------------------------------------------===// // mulf(negf(x), negf(y)) -> mulf(x,y) -// (retain fastmath flags of original mulf) +// (retain fastmath flags and denormal mode of the original divf) def MulFOfNegF : - Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), - (Arith_MulFOp $x, $y, $fmf), + Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_), + (Arith_NegFOp $y, $_, $_), $fmf, $mode), + (Arith_MulFOp $x, $y, $fmf, $mode), [(Constraint> $x, $y)]>; //===----------------------------------------------------------------------===// @@ -433,10 +434,11 @@ def MulFOfNegF : //===----------------------------------------------------------------------===// // divf(negf(x), negf(y)) -> divf(x,y) -// (retain fastmath flags of original divf) +// (retain fastmath flags and denormal mode of the original divf) def DivFOfNegF : - Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), - (Arith_DivFOp $x, $y, $fmf), + Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_), + (Arith_NegFOp $y, $_, $_), $fmf, $mode), + (Arith_DivFOp $x, $y, $fmf, $mode), [(Constraint> $x, $y)]>; #endif // ARITH_PATTERNS diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 254f54d9e459e1..1b8a459c6e8c4b 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -952,7 +952,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { - /// negf(negf(x)) -> x + // negf(negf(x)) -> x if (auto op = this->getOperand().getDefiningOp()) return op.getOperand(); return constFoldUnaryOp(adaptor.getOperands(), @@ -982,6 +982,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_PosZeroFloat())) return getLhs(); + // Simplifies subf(x, rhs) to x if the following conditions are met: + // 1. `rhs` is a denormal floating-point value. + // 2. The denormal mode for the operation is set to positive zero. + bool isPositiveZeroMode = + getDenormalModeAttr().getValue() == DenormalMode::positive_zero; + if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat())) + return getLhs(); + return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a - b; }); @@ -2635,6 +2643,20 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return nullptr; } +//===----------------------------------------------------------------------===// +// DenormalModeAttr +//===----------------------------------------------------------------------===// + +LogicalResult DenormalModeAttr::verify( + llvm::function_ref emitError, + DenormalMode mode) { + auto value = static_cast(mode); + bool isSingleBitSet = (value & (value - 1)) == 0; + if (!isSingleBitSet) + return emitError() << "expected only a single denormal mode"; + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 26d9d2b091750c..ce614208bef5cc 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1501,12 +1501,17 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { - auto fastAttr = - llvm::dyn_cast(attr.getValue()); - if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { - attrToElide = attr.getName().str(); - elidedAttrs.push_back(attrToElide); - break; + if (auto fastAttr = + llvm::dyn_cast(attr.getValue())) { + if (fastAttr.getValue() == arith::FastMathFlags::none) { + elidedAttrs.push_back(attr.getName().str()); + } + } + if (auto denormAttr = + llvm::dyn_cast(attr.getValue())) { + if (denormAttr.getValue() == arith::DenormalMode::ieee) { + elidedAttrs.push_back(attr.getName().str()); + } } } p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs); diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 15a3a1fb50dc9e..fa3b0d894c995b 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -320,7 +320,7 @@ int collectStats(MlirOperation operation) { // clang-format off // CHECK-LABEL: @stats // CHECK: Number of operations: 12 - // CHECK: Number of attributes: 5 + // CHECK: Number of attributes: 6 // CHECK: Number of blocks: 3 // CHECK: Number of regions: 3 // CHECK: Number of values: 9 diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a386a178b78995..f56bf0980b13c1 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -3189,3 +3189,26 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, } } #-} + +// ----- + +// CHECK-LABEL: @test_fold_denorm +// CHECK-SAME: %[[ARG0:.+]]: f32 +func.func @test_fold_denorm(%arg0: f32) -> f32 { + // CHECK-NOT: arith.subf + // CHECK: return %[[ARG0]] : f32 + %c_denorm = arith.constant 1.4e-45 : f32 + %sub = arith.subf %arg0, %c_denorm denormal : f32 + return %sub : f32 +} + +// ----- + +// CHECK-LABEL: @test_expect_not_to_fold_denorm +func.func @test_expect_not_to_fold_denorm(%arg0: f32, %arg1 : f32) -> (f32, f32) { + // CHECK-COUNT-2: arith.subf + %c_denorm = arith.constant 1.4e-45 : f32 + %sub = arith.subf %arg0, %c_denorm denormal : f32 + %sub_1 = arith.subf %arg1, %c_denorm denormal : f32 + return %sub, %sub_1 : f32, f32 +} diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 088da475e8eb4c..4999008e572fc9 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -853,3 +853,11 @@ func.func @select_tensor_encoding( %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo"> return %0 : tensor<8xi32, "foo"> } + +// ----- + +func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 { + // expected-error @below{{expected only a single denormal mode}} + %0 = arith.subf %arg0, %arg1 denormal : f32 + return %0 : f32 +} \ No newline at end of file diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index f684e02344a517..c019974020879f 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1161,3 +1161,73 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) { %3 = arith.shli %arg0, %arg1 overflow : i64 return } + +// CHECK-LABEL: check_denorm_modes +func.func @check_denorm_modes(%arg0: f32, %arg1: f32, %arg2: f32) { + %c_denorm = arith.constant 1.4e-45 : f32 + // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal : f32 + %sub_preserve_sign = arith.subf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal : f32 + %sub_positive_zero = arith.subf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} : f32 + %sub_ieee = arith.subf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal : f32 + %add_preserve_sign = arith.addf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal : f32 + %add_positive_zero = arith.addf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} : f32 + %add_ieee = arith.addf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal : f32 + %mul_preserve_sign = arith.mulf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal : f32 + %mul_positive_zero = arith.mulf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} : f32 + %mul_ieee = arith.mulf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal : f32 + %div_preserve_sign = arith.divf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal : f32 + %div_positive_zero = arith.divf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} : f32 + %div_ieee = arith.divf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal : f32 + %maximumf_preserve_sign = arith.maximumf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal : f32 + %maximumf_positive_zero = arith.maximumf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} : f32 + %maximumf_ieee = arith.maximumf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal : f32 + %maxnumf_preserve_sign = arith.maxnumf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal : f32 + %maxnumf_positive_zero = arith.maxnumf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} : f32 + %maxnumf_ieee = arith.maxnumf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal : f32 + %minimumf_preserve_sign = arith.minimumf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal : f32 + %minimumf_positive_zero = arith.minimumf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} : f32 + %minimumf_ieee = arith.minimumf %arg2, %c_denorm denormal : f32 + + // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal : f32 + %minnumf_preserve_sign = arith.minnumf %arg0, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal : f32 + %minnumf_positive_zero = arith.minnumf %arg1, %c_denorm denormal : f32 + // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} : f32 + %minnumf_ieee = arith.minnumf %arg2, %c_denorm denormal : f32 + + + // CHECK: %{{.+}} = arith.negf %{{.+}} denormal : f32 + %negf_preserve_sign = arith.negf %arg0 denormal : f32 + // CHECK: %{{.+}} = arith.negf %{{.+}} denormal : f32 + %negf_positive_sign = arith.negf %arg0 denormal : f32 + // CHECK: %{{.+}} = arith.negf %{{.+}} : f32 + %negf_ieee = arith.negf %arg0 denormal : f32 + + return +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index a59472377a732c..e3b6958cfa8816 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref // ----- func.func @generic(%arg0: memref) { - // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32}} + // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{denormal = #arith.denormal, fastmath = #arith.fastmath}> : (f32, f32) -> f32}} linalg.generic { indexing_maps = [ affine_map<(i, j) -> (i, j)> ], iterator_types = ["parallel", "parallel"]} From e44b3e889769aede3aae42e7b0c10069d5ee7897 Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Thu, 21 Nov 2024 20:40:03 +0100 Subject: [PATCH 2/7] new line --- mlir/test/Dialect/Arith/invalid.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 4999008e572fc9..ca86d51fd3523d 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -860,4 +860,4 @@ func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 { // expected-error @below{{expected only a single denormal mode}} %0 = arith.subf %arg0, %arg1 denormal : f32 return %0 : f32 -} \ No newline at end of file +} From a50d31cb16619d523999a9a423b335257ab1a4d8 Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Thu, 21 Nov 2024 21:37:18 +0100 Subject: [PATCH 3/7] plain enum --- mlir/include/mlir/Dialect/Arith/IR/ArithBase.td | 9 ++++----- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 14 -------------- mlir/test/Dialect/Arith/invalid.mlir | 8 -------- 3 files changed, 4 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index 4309c0618667a8..d27ea5edcc8c8d 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -191,17 +191,17 @@ def Arith_RoundingModeAttr : I32EnumAttr< // The default mode. Denormals are preserved and processed as defined // by IEEE 754 rules. -def Arith_DenormalModeIEEE : I32BitEnumAttrCaseNone<"ieee">; +def Arith_DenormalModeIEEE : I32EnumAttrCase<"ieee", 0>; // A mode where denormal numbers are flushed to zero, but the sign of the zero // (+0 or -0) is preserved. -def Arith_DenormalModePreserveSign : I32BitEnumAttrCase<"preserve_sign", 1>; +def Arith_DenormalModePreserveSign : I32EnumAttrCase<"preserve_sign", 1>; // A mode where all denormal numbers are flushed to positive zero (+0), // ignoring the sign of the original number. -def Arith_DenormalModePositiveZero : I32BitEnumAttrCase<"positive_zero", 2>; +def Arith_DenormalModePositiveZero : I32EnumAttrCase<"positive_zero", 2>; -def Arith_DenormalMode : I32BitEnumAttr< +def Arith_DenormalMode : I32EnumAttr< "DenormalMode", "denormal mode arith", [Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign, Arith_DenormalModePositiveZero]> { @@ -212,7 +212,6 @@ def Arith_DenormalMode : I32BitEnumAttr< def Arith_DenormalModeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; - let genVerifyDecl = 1; } #endif // ARITH_BASE diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1b8a459c6e8c4b..47766f36ad05cf 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2643,20 +2643,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return nullptr; } -//===----------------------------------------------------------------------===// -// DenormalModeAttr -//===----------------------------------------------------------------------===// - -LogicalResult DenormalModeAttr::verify( - llvm::function_ref emitError, - DenormalMode mode) { - auto value = static_cast(mode); - bool isSingleBitSet = (value & (value - 1)) == 0; - if (!isSingleBitSet) - return emitError() << "expected only a single denormal mode"; - return success(); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index ca86d51fd3523d..088da475e8eb4c 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -853,11 +853,3 @@ func.func @select_tensor_encoding( %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo"> return %0 : tensor<8xi32, "foo"> } - -// ----- - -func.func @test_denormal_mode(%arg0: f32, %arg1: f32) -> f32 { - // expected-error @below{{expected only a single denormal mode}} - %0 = arith.subf %arg0, %arg1 denormal : f32 - return %0 : f32 -} From 9fda8198fea96cf160eaddd32a5cdd3d40023f5e Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Fri, 22 Nov 2024 15:56:40 +0100 Subject: [PATCH 4/7] drop llvm --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index ce614208bef5cc..98810c5f19d798 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1502,13 +1502,13 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { if (auto fastAttr = - llvm::dyn_cast(attr.getValue())) { + dyn_cast(attr.getValue())) { if (fastAttr.getValue() == arith::FastMathFlags::none) { elidedAttrs.push_back(attr.getName().str()); } } if (auto denormAttr = - llvm::dyn_cast(attr.getValue())) { + dyn_cast(attr.getValue())) { if (denormAttr.getValue() == arith::DenormalMode::ieee) { elidedAttrs.push_back(attr.getName().str()); } From 3f5eec0e41f8d408fa9827224223e29451a62644 Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Fri, 22 Nov 2024 15:59:50 +0100 Subject: [PATCH 5/7] format --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 98810c5f19d798..33b8f5842f61ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1501,14 +1501,12 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { - if (auto fastAttr = - dyn_cast(attr.getValue())) { + if (auto fastAttr = dyn_cast(attr.getValue())) { if (fastAttr.getValue() == arith::FastMathFlags::none) { elidedAttrs.push_back(attr.getName().str()); } } - if (auto denormAttr = - dyn_cast(attr.getValue())) { + if (auto denormAttr = dyn_cast(attr.getValue())) { if (denormAttr.getValue() == arith::DenormalMode::ieee) { elidedAttrs.push_back(attr.getName().str()); } From 7f41d63caebd22ce5b7ff9139c8bfa0272401afe Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Fri, 22 Nov 2024 20:54:35 +0100 Subject: [PATCH 6/7] fix bug linalg --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 33b8f5842f61ed..f8281eeb21bfa8 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1498,17 +1498,16 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) { void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { SmallVector elidedAttrs; - std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); for (const auto &attr : payloadOp->getAttrs()) { if (auto fastAttr = dyn_cast(attr.getValue())) { if (fastAttr.getValue() == arith::FastMathFlags::none) { - elidedAttrs.push_back(attr.getName().str()); + elidedAttrs.push_back(attr.getName()); } } if (auto denormAttr = dyn_cast(attr.getValue())) { if (denormAttr.getValue() == arith::DenormalMode::ieee) { - elidedAttrs.push_back(attr.getName().str()); + elidedAttrs.push_back(attr.getName()); } } } From e99b6c2904e25d17013edfad334f1bc4eed77ea6 Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Fri, 22 Nov 2024 21:02:33 +0100 Subject: [PATCH 7/7] fix test use name to make sure we don't print ieee --- mlir/test/Dialect/Arith/ops.mlir | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index c019974020879f..5892e2a3d078c7 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1162,63 +1162,65 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) { return } -// CHECK-LABEL: check_denorm_modes +// CHECK-LABEL: check_denorm_modes( +// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) func.func @check_denorm_modes(%arg0: f32, %arg1: f32, %arg2: f32) { + // CHECK: %[[CST:.+]] = arith.constant 1.401300e-45 : f32 %c_denorm = arith.constant 1.4e-45 : f32 // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal : f32 %sub_preserve_sign = arith.subf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} denormal : f32 %sub_positive_zero = arith.subf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.subf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.subf %[[ARG2]], %[[CST]] : f32 %sub_ieee = arith.subf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal : f32 %add_preserve_sign = arith.addf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} denormal : f32 %add_positive_zero = arith.addf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.addf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.addf %[[ARG2]], %[[CST]] : f32 %add_ieee = arith.addf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal : f32 %mul_preserve_sign = arith.mulf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} denormal : f32 %mul_positive_zero = arith.mulf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.mulf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.mulf %[[ARG2]], %[[CST]] : f32 %mul_ieee = arith.mulf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal : f32 %div_preserve_sign = arith.divf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} denormal : f32 %div_positive_zero = arith.divf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.divf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.divf %[[ARG2]], %[[CST]] : f32 %div_ieee = arith.divf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal : f32 %maximumf_preserve_sign = arith.maximumf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} denormal : f32 %maximumf_positive_zero = arith.maximumf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.maximumf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.maximumf %[[ARG2]], %[[CST]] : f32 %maximumf_ieee = arith.maximumf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal : f32 %maxnumf_preserve_sign = arith.maxnumf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} denormal : f32 %maxnumf_positive_zero = arith.maxnumf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.maxnumf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.maxnumf %[[ARG2]], %[[CST]] : f32 %maxnumf_ieee = arith.maxnumf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal : f32 %minimumf_preserve_sign = arith.minimumf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} denormal : f32 %minimumf_positive_zero = arith.minimumf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.minimumf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.minimumf %[[ARG2]], %[[CST]] : f32 %minimumf_ieee = arith.minimumf %arg2, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal : f32 %minnumf_preserve_sign = arith.minnumf %arg0, %c_denorm denormal : f32 // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} denormal : f32 %minnumf_positive_zero = arith.minnumf %arg1, %c_denorm denormal : f32 - // CHECK: %{{.+}} = arith.minnumf %{{.+}}, %{{.+}} : f32 + // CHECK: %{{.+}} = arith.minnumf %[[ARG2]], %[[CST]] : f32 %minnumf_ieee = arith.minnumf %arg2, %c_denorm denormal : f32 @@ -1226,7 +1228,7 @@ func.func @check_denorm_modes(%arg0: f32, %arg1: f32, %arg2: f32) { %negf_preserve_sign = arith.negf %arg0 denormal : f32 // CHECK: %{{.+}} = arith.negf %{{.+}} denormal : f32 %negf_positive_sign = arith.negf %arg0 denormal : f32 - // CHECK: %{{.+}} = arith.negf %{{.+}} : f32 + // CHECK: %{{.+}} = arith.negf %[[ARG0]] : f32 %negf_ieee = arith.negf %arg0 denormal : f32 return