Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Arith] Add denormal attribute to binary/unary operations #112700

Merged
merged 7 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
template <typename SourceOp, typename TargetOp>
class AttrConvertFastMathToLLVM {
public:
AttrConvertFastMathToLLVM(SourceOp srcOp) {
explicit AttrConvertFastMathToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
Expand Down Expand Up @@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM {
template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
explicit AttrConvertOverflowToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
Expand Down Expand Up @@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM {
"LLVM::FPExceptionBehaviorOpInterface");

public:
AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};

Expand Down
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,37 @@ def Arith_RoundingModeAttr : I32EnumAttr<
let cppNamespace = "::mlir::arith";
}

//===----------------------------------------------------------------------===//
// Arith_DenormalMode
//===----------------------------------------------------------------------===//

// Denormal mode is applied on operands and results. For example, if denormal =
// preserve_sign, operands and results will be flushed to sign preserving zero.
// We do not distinguish between operands and results.

// The default mode. Denormals are preserved and processed as defined
// by IEEE 754 rules.
def Arith_DenormalModeIEEE : I32EnumAttrCase<"ieee", 0>;

// A mode where denormal numbers are flushed to zero, but the sign of the zero
// (+0 or -0) is preserved.
def Arith_DenormalModePreserveSign : I32EnumAttrCase<"preserve_sign", 1>;

// A mode where all denormal numbers are flushed to positive zero (+0),
// ignoring the sign of the original number.
def Arith_DenormalModePositiveZero : I32EnumAttrCase<"positive_zero", 2>;

def Arith_DenormalMode : I32EnumAttr<
"DenormalMode", "denormal mode arith",
[Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
Arith_DenormalModePositiveZero]> {
let cppNamespace = "::mlir::arith";
let genSpecializedAttr = 0;
}

def Arith_DenormalModeAttr :
EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
let assemblyFormat = "`<` $value `>`";
kuhar marked this conversation as resolved.
Show resolved Hide resolved
}

#endif // ARITH_BASE
22 changes: 14 additions & 8 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,35 @@ class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
// Base class for floating point unary operations.
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
Arith_UnaryOp<mnemonic,
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
traits)>,
Arguments<(ins FloatLike:$operand,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
DefaultValuedAttr<
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
Results<(outs FloatLike:$result)> {
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
(`denormal` `` $denormal^)?
attr-dict `:` type($result) }];
}

// Base class for floating point binary operations.
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic,
!listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
!listconcat([Pure,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
traits)>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
DefaultValuedAttr<
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
Results<(outs FloatLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
(`denormal` `` $denormal^)?
attr-dict `:` type($result) }];
}

Expand Down Expand Up @@ -1085,7 +1094,6 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
let hasFolder = 1;
}


//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
Expand All @@ -1111,8 +1119,6 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
%x = arith.mulf %y, %z : tensor<4x?xbf16>
```

TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down
40 changes: 37 additions & 3 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,12 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
return "fastmath";
}]
>

];
}

def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
Access to operation integer overflow flags.
}];

let cppNamespace = "::mlir::arith";
Expand Down Expand Up @@ -108,7 +107,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI

def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
let description = [{
Access to op rounding mode.
Access to operation rounding mode.
}];

let cppNamespace = "::mlir::arith";
Expand Down Expand Up @@ -139,4 +138,39 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
];
}


def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> {
let description = [{
Access the operation denormal modes.
}];

let cppNamespace = "::mlir::arith";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns a DenormalModeAttr attribute for the operation",
/*returnType=*/ "DenormalModeAttr",
/*methodName=*/ "getDenormalModeAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getDenormalAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the DenormalModeAttr attribute for
the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getDenormalModeAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "denormal";
}]
>
];
}


#endif // ARITH_OPS_INTERFACES
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/Matchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,12 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() {
}};
}

/// Matches a constant scalar / vector splat / tensor splat with denormal
/// values.
inline detail::constant_float_predicate_matcher m_isDenormalFloat() {
return {[](const APFloat &value) { return value.isDenormal(); }};
}

/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_predicate_matcher m_Zero() {
return {[](const APInt &value) { return 0 == value; }};
Expand Down
67 changes: 47 additions & 20 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,49 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};

template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
struct DenormalOpConversionToLLVMPattern
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
AttrConvert>::VectorConvertToLLVMPattern;

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Here, we need a legalization step. LLVM provides a function-level
// attribute for denormal; here, we need to move this information from the
// operation to the function, making sure all the operations in the same
// function are consistent.
if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee)
return rewriter.notifyMatchFailure(
op, "only ieee denormal mode is supported at the moment");

StringRef arithDenormalAttrName = SourceOp::getDenormalModeAttrName();
op->removeAttr(arithDenormalAttrName);
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
AttrConvert>::matchAndRewrite(op, adaptor,
rewriter);
chelini marked this conversation as resolved.
Show resolved Hide resolved
}
};

//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//

using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
arith::AttrConvertFastMathToLLVM>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
arith::AttrConvertFastMathToLLVM>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
Expand All @@ -83,38 +110,38 @@ using FPToSIOpLowering =
using FPToUIOpLowering =
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
arith::AttrConvertFastMathToLLVM>;
using MaxNumFOpLowering =
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
arith::AttrConvertFastMathToLLVM>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
arith::AttrConvertFastMathToLLVM>;
using MinNumFOpLowering =
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
arith::AttrConvertFastMathToLLVM>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
arith::AttrConvertFastMathToLLVM>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
arith::AttrConvertFastMathToLLVM>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
arith::AttrConvertFastMathToLLVM>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
Expand All @@ -131,8 +158,8 @@ using ShRUIOpLowering =
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
arith::AttrConvertFastMathToLLVM>;
DenormalOpConversionToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
arith::AttrConvertFastMathToLLVM>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,23 @@ def TruncIShrUIMulIToMulUIExtended :
//===----------------------------------------------------------------------===//

// mulf(negf(x), negf(y)) -> mulf(x,y)
// (retain fastmath flags of original mulf)
// (retain fastmath flags and denormal mode of the original divf)
def MulFOfNegF :
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
(Arith_MulFOp $x, $y, $fmf),
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_),
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
(Arith_MulFOp $x, $y, $fmf, $mode),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;

//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//

// divf(negf(x), negf(y)) -> divf(x,y)
// (retain fastmath flags of original divf)
// (retain fastmath flags and denormal mode of the original divf)
def DivFOfNegF :
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
(Arith_DivFOp $x, $y, $fmf),
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_),
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
(Arith_DivFOp $x, $y, $fmf, $mode),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;

#endif // ARITH_PATTERNS
10 changes: 9 additions & 1 deletion mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//

OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
/// negf(negf(x)) -> x
// negf(negf(x)) -> x
if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
return op.getOperand();
return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
Expand Down Expand Up @@ -982,6 +982,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();

// Simplifies subf(x, rhs) to x if the following conditions are met:
// 1. `rhs` is a denormal floating-point value.
// 2. The denormal mode for the operation is set to positive zero.
bool isPositiveZeroMode =
getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
return getLhs();

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a - b; });
Expand Down
16 changes: 9 additions & 7 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1498,15 +1498,17 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {

void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
for (const auto &attr : payloadOp->getAttrs()) {
auto fastAttr =
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
attrToElide = attr.getName().str();
elidedAttrs.push_back(attrToElide);
break;
if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
if (fastAttr.getValue() == arith::FastMathFlags::none) {
elidedAttrs.push_back(attr.getName());
}
}
if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
if (denormAttr.getValue() == arith::DenormalMode::ieee) {
elidedAttrs.push_back(attr.getName());
}
}
}
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
Expand Down
Loading
Loading