diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 2b4c8b609cfdd..7ceec72144eb5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -139,6 +139,7 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> { // Set the $predicate index to -1 to indicate there is no matching operand // and decrement the following indices. list llvmArgIndices = [-1, 0, 1]; + let hasFolder = 1; } // Other floating-point operations. @@ -1561,6 +1562,17 @@ def LLVM_ConstantOp }]> ]; + let extraClassDeclaration = [{ + /// Whether the constant op can be constructed with a particular value and + /// type. + static bool isBuildableWith(Attribute value, Type type); + + /// Build the constant op with `value` and `type` if possible, otherwise + /// returns null. + static ConstantOp materialize(OpBuilder &builder, Attribute value, + Type type, Location loc); + }]; + let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index fd0d2b3fb3c1a..ae01f7c462152 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -98,7 +98,7 @@ static Type getI1SameShape(Type type) { } //===----------------------------------------------------------------------===// -// Printing, parsing and builder for LLVM::CmpOp. +// Printing, parsing, folding and builder for LLVM::CmpOp. //===----------------------------------------------------------------------===// void ICmpOp::print(OpAsmPrinter &p) { @@ -175,6 +175,42 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { return parseCmpOp(parser, result); } +/// Returns a scalar or vector boolean attribute of the given type. +static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { + auto boolAttr = BoolAttr::get(ctx, value); + ShapedType shapedType = dyn_cast(type); + if (!shapedType) + return boolAttr; + return DenseElementsAttr::get(shapedType, boolAttr); +} + +OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { + if (getPredicate() != ICmpPredicate::eq && + getPredicate() != ICmpPredicate::ne) + return {}; + + // cmpi(eq/ne, x, x) -> true/false + if (getLhs() == getRhs()) + return getBoolAttribute(getType(), getContext(), + getPredicate() == ICmpPredicate::eq); + + // cmpi(eq/ne, alloca, null) -> false/true + if (getLhs().getDefiningOp() && getRhs().getDefiningOp()) + return getBoolAttribute(getType(), getContext(), + getPredicate() == ICmpPredicate::ne); + + // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null) + if (getLhs().getDefiningOp() && getRhs().getDefiningOp()) { + Value lhs = getLhs(); + Value rhs = getRhs(); + getLhsMutable().assign(rhs); + getRhsMutable().assign(lhs); + return getResult(); + } + + return {}; +} + //===----------------------------------------------------------------------===// // Printing, parsing and verification for LLVM::AllocaOp. //===----------------------------------------------------------------------===// @@ -2443,7 +2479,7 @@ Region *LLVMFuncOp::getCallableRegion() { } //===----------------------------------------------------------------------===// -// Verification for LLVM::ConstantOp. +// ConstantOp. //===----------------------------------------------------------------------===// LogicalResult LLVM::ConstantOp::verify() { @@ -2503,6 +2539,25 @@ LogicalResult LLVM::ConstantOp::verify() { return success(); } +bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { + // The value's type must be the same as the provided type. + auto typedAttr = dyn_cast(value); + if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type)) + return false; + // The value's type must be an LLVM compatible type. + if (!isCompatibleType(type)) + return false; + // TODO: Add support for additional attributes kinds once needed. + return isa(value); +} + +ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (isBuildableWith(value, type)) + return builder.create(loc, cast(value)); + return nullptr; +} + // Constant op constant-folds to its value. OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } @@ -3097,11 +3152,7 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - // TODO: Accept more possible attributes. So far, only IntegerAttr may come - // up. - if (!isa(value)) - return nullptr; - return builder.create(loc, type, value); + return LLVM::ConstantOp::materialize(builder, value, type, loc); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 6b2cac14f2985..3e7f689bdc03e 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -1,5 +1,34 @@ // RUN: mlir-opt --pass-pipeline='builtin.module(llvm.func(canonicalize{test-convergence}))' %s -split-input-file | FileCheck %s +// CHECK-LABEL: @fold_icmp_eq +llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 { + // CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1 + %0 = llvm.icmp "eq" %arg0, %arg0 : i32 + // CHECK: llvm.return %[[C0]] + llvm.return %0 : i1 +} + +// CHECK-LABEL: @fold_icmp_ne +llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> { + // CHECK: %[[C0:.*]] = llvm.mlir.constant(dense : vector<2xi1>) : vector<2xi1> + %0 = llvm.icmp "ne" %arg0, %arg0 : vector<2xi32> + // CHECK: llvm.return %[[C0]] + llvm.return %0 : vector<2xi1> +} + +// CHECK-LABEL: @fold_icmp_alloca +llvm.func @fold_icmp_alloca() -> i1 { + // CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1 + %c0 = llvm.mlir.null : !llvm.ptr + %c1 = arith.constant 1 : i64 + %0 = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr + %1 = llvm.icmp "ne" %c0, %0 : !llvm.ptr + // CHECK: llvm.return %[[C0]] + llvm.return %1 : i1 +} + +// ----- + // CHECK-LABEL: fold_extractvalue llvm.func @fold_extractvalue() -> i32 { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32