Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][arith] Rename AtomicRMWKind's maxfmaximumf, minfminimumf #66135

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,25 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir::arith";
}

def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;

def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
let cppNamespace = "::mlir::arith";
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1594,13 +1594,13 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::add;
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::maximumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::minimumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
.Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
.Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
.Case([](arith::MinimumFOp) { return arith::AtomicRMWKind::minf; })
.Case([](arith::MaximumFOp) { return arith::AtomicRMWKind::maxf; })
.Case(
[](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
.Case(
[](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
.Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2369,7 +2369,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc,
bool useOnlyFiniteValue) {
switch (kind) {
case AtomicRMWKind::maxf: {
case AtomicRMWKind::maximumf: {
const llvm::fltSemantics &semantic =
llvm::cast<FloatType>(resultType).getFloatSemantics();
APFloat identity = useOnlyFiniteValue
Expand All @@ -2390,7 +2390,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getIntegerAttr(
resultType, APInt::getSignedMinValue(
llvm::cast<IntegerType>(resultType).getWidth()));
case AtomicRMWKind::minf: {
case AtomicRMWKind::minimumf: {
const llvm::fltSemantics &semantic =
llvm::cast<FloatType>(resultType).getFloatSemantics();
APFloat identity = useOnlyFiniteValue
Expand Down Expand Up @@ -2426,8 +2426,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Floating-point operations.
.Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maxf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
Expand Down Expand Up @@ -2482,9 +2482,9 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MulFOp>(loc, lhs, rhs);
case AtomicRMWKind::muli:
return builder.create<arith::MulIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxf:
case AtomicRMWKind::maximumf:
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minf:
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2549,9 +2549,9 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
dims.erase(dims.begin() + reductionDim);
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
Value neutralForMaxF =
arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc,
/*useOnlyFiniteValue=*/true);
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
elementType, b, loc,
/*useOnlyFiniteValue=*/true);
Value neutralForMaxFInit =
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
.result();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3402,8 +3402,8 @@ LogicalResult AtomicRMWOp::verify() {
"expects the number of subscripts to be equal to memref rank");
switch (getKind()) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::maximumf:
case arith::AtomicRMWKind::minimumf:
case arith::AtomicRMWKind::mulf:
if (!llvm::isa<FloatType>(getValue().getType()))
return emitOpError() << "with kind '"
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace {
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
/// `memref.generic_atomic_rmw` with the expanded code.
///
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
Expand All @@ -54,10 +54,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
PatternRewriter &rewriter) const final {
arith::CmpFPredicate predicate;
switch (op.getKind()) {
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::maximumf:
predicate = arith::CmpFPredicate::OGT;
break;
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::minimumf:
predicate = arith::CmpFPredicate::OLT;
break;
default:
Expand Down Expand Up @@ -137,8 +137,8 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
return op.getKind() != arith::AtomicRMWKind::maxf &&
op.getKind() != arith::AtomicRMWKind::minf;
return op.getKind() != arith::AtomicRMWKind::maximumf &&
op.getKind() != arith::AtomicRMWKind::minimumf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::muli:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MUL, vector);
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::minimumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINF, vector);
case arith::AtomicRMWKind::mins:
Expand All @@ -502,7 +502,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINUI, vector);
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::maximumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MAXF, vector);
case arith::AtomicRMWKind::maxs:
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {

func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = memref.alloc() : memref<100x100xi32>
%1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minf") -> (f32) {
%1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (f32) {
%2 = affine.load %0[%i, %j] : memref<100x100xi32>
// expected-error@+1 {{types mismatch between yield op and its parent}}
affine.yield %2 : i32
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Affine/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
func.func @parallel(%A : memref<100x100xf32>, %N : index) {
// CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10)
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) {
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minf", "maxf") -> (f32, f32)
%0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minf", "maxf") -> (f32, f32) {
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minimumf", "maximumf") -> (f32, f32)
%0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minimumf", "maximumf") -> (f32, f32) {
%2 = affine.load %A[%i0 + %i0, %j0 + %j1] : memref<100x100xf32>
affine.yield %2, %2 : f32, f32
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/expand-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
Expand Down