Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][llvm] Add icmp folder #65343

Merged
merged 3 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
// Set the $predicate index to -1 to indicate there is no matching operand
// and decrement the following indices.
list<int> llvmArgIndices = [-1, 0, 1];
let hasFolder = 1;
}

// Other floating-point operations.
Expand Down Expand Up @@ -1561,6 +1562,17 @@ def LLVM_ConstantOp
}]>
];

let extraClassDeclaration = [{
/// Whether the constant op can be constructed with a particular value and
/// type.
static bool isBuildableWith(Attribute value, Type type);

/// Build the constant op with `value` and `type` if possible, otherwise
/// returns null.
static ConstantOp materialize(OpBuilder &builder, Attribute value,
Type type, Location loc);
}];

let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down
65 changes: 58 additions & 7 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ static Type getI1SameShape(Type type) {
}

//===----------------------------------------------------------------------===//
// Printing, parsing and builder for LLVM::CmpOp.
// Printing, parsing, folding and builder for LLVM::CmpOp.
//===----------------------------------------------------------------------===//

void ICmpOp::print(OpAsmPrinter &p) {
Expand Down Expand Up @@ -175,6 +175,42 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
return parseCmpOp<FCmpPredicate>(parser, result);
}

/// Returns a scalar or vector boolean attribute of the given type.
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
auto boolAttr = BoolAttr::get(ctx, value);
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
ShapedType shapedType = dyn_cast<ShapedType>(type);

if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
}

OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
if (getPredicate() != ICmpPredicate::eq &&
getPredicate() != ICmpPredicate::ne)
return {};

// cmpi(eq/ne, x, x) -> true/false
if (getLhs() == getRhs())
return getBoolAttribute(getType(), getContext(),
getPredicate() == ICmpPredicate::eq);

// cmpi(eq/ne, alloca, null) -> false/true
if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<NullOp>())
return getBoolAttribute(getType(), getContext(),
getPredicate() == ICmpPredicate::ne);

// cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
if (getLhs().getDefiningOp<NullOp>() && getRhs().getDefiningOp<AllocaOp>()) {
Value lhs = getLhs();
Value rhs = getRhs();
getLhsMutable().assign(rhs);
getRhsMutable().assign(lhs);
return getResult();
}

return {};
}

//===----------------------------------------------------------------------===//
// Printing, parsing and verification for LLVM::AllocaOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2443,7 +2479,7 @@ Region *LLVMFuncOp::getCallableRegion() {
}

//===----------------------------------------------------------------------===//
// Verification for LLVM::ConstantOp.
// ConstantOp.
//===----------------------------------------------------------------------===//

LogicalResult LLVM::ConstantOp::verify() {
Expand Down Expand Up @@ -2503,6 +2539,25 @@ LogicalResult LLVM::ConstantOp::verify() {
return success();
}

bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
// The value's type must be the same as the provided type.
auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
return false;
// The value's type must be an LLVM compatible type.
if (!isCompatibleType(type))
return false;
// TODO: Add support for additional attributes kinds once needed.
return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}

ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (isBuildableWith(value, type))
return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
return nullptr;
}

// Constant op constant-folds to its value.
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }

Expand Down Expand Up @@ -3097,11 +3152,7 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,

Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
// TODO: Accept more possible attributes. So far, only IntegerAttr may come
// up.
if (!isa<IntegerAttr>(value))
return nullptr;
return builder.create<LLVM::ConstantOp>(loc, type, value);
return LLVM::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/LLVMIR/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
// RUN: mlir-opt --pass-pipeline='builtin.module(llvm.func(canonicalize{test-convergence}))' %s -split-input-file | FileCheck %s

// CHECK-LABEL: @fold_icmp_eq
llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
%0 = llvm.icmp "eq" %arg0, %arg0 : i32
// CHECK: llvm.return %[[C0]]
llvm.return %0 : i1
}

// CHECK-LABEL: @fold_icmp_ne
llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(dense<false> : vector<2xi1>) : vector<2xi1>
%0 = llvm.icmp "ne" %arg0, %arg0 : vector<2xi32>
// CHECK: llvm.return %[[C0]]
llvm.return %0 : vector<2xi1>
}

// CHECK-LABEL: @fold_icmp_alloca
llvm.func @fold_icmp_alloca() -> i1 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
%c0 = llvm.mlir.null : !llvm.ptr
%c1 = arith.constant 1 : i64
%0 = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
%1 = llvm.icmp "ne" %c0, %0 : !llvm.ptr
// CHECK: llvm.return %[[C0]]
llvm.return %1 : i1
}

// -----

// CHECK-LABEL: fold_extractvalue
llvm.func @fold_extractvalue() -> i32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
Expand Down