diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h index 5d4774861bdfd3..6e617ef40a53d7 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h @@ -18,16 +18,18 @@ class Value; namespace affine { void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry); -/// Compute whether the given values are equal. Return "failure" if equality -/// could not be determined. `value1`/`value2` must be index-typed. +/// Compute a constant delta of the given two values. Return "failure" if we +/// cannot determine a constant delta. `value1`/`value2` must be index-typed. /// -/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work -/// around limitations in `FlatLinearConstraints`, this function fully composes +/// This function is similar to +/// `ValueBoundsConstraintSet::computeConstantDelta`. To work around +/// limitations in `FlatLinearConstraints`, this function fully composes /// `value1` and `value2` (if they are the result of affine.apply ops) before /// populating the constraint set. The folding/composing logic can see /// opportunities for simplifications that the constraint set implementation /// cannot see. -FailureOr fullyComposeAndCheckIfEqual(Value value1, Value value2); +FailureOr fullyComposeAndComputeConstantDelta(Value value1, + Value value2); } // namespace affine } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index fc0c80036ff79a..9ab20e20d97542 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); /// op. 
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); -/// Same behavior as `isDisjointTransferSet` but doesn't require the operations -/// to have the same tensor/memref. This allows comparing operations accessing -/// different tensors. +/// Return true if we can prove that the transfer operations access disjoint +/// memory, without requiring the accessed tensor/memref to be the same. +/// +/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values +/// via ValueBoundsOpInterface. bool isDisjointTransferIndices(VectorTransferOpInterface transferA, - VectorTransferOpInterface transferB); + VectorTransferOpInterface transferB, + bool testDynamicValueUsingBounds = false); /// Return true if we can prove that the transfer operations access disjoint -/// memory. +/// memory, requiring the operations to access the same tensor/memref. +/// +/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values +/// via ValueBoundsOpInterface. bool isDisjointTransferSet(VectorTransferOpInterface transferA, - VectorTransferOpInterface transferB); + VectorTransferOpInterface transferB, + bool testDynamicValueUsingBounds = false); /// Return the result value of reducing two scalar/vector values with the /// corresponding arith operation. diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 2687d79aec68eb..8f11c563e0cbd9 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -176,6 +176,16 @@ class ValueBoundsConstraintSet { presburger::BoundType type, AffineMap map, ValueDimList mapOperands, StopConditionFn stopCondition = nullptr, bool closedUB = false); + /// Compute a constant delta between the given two values. Return "failure" + /// if a constant delta could not be determined. + /// + /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are + /// index-typed. 
+ static FailureOr + computeConstantDelta(Value value1, Value value2, + std::optional dim1 = std::nullopt, + std::optional dim2 = std::nullopt); + /// Compute whether the given values/dimensions are equal. Return "failure" if /// equality could not be determined. /// diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index d47c8eb8ccb427..e0c3abe7a0f71d 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -103,8 +103,8 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( }); } -FailureOr mlir::affine::fullyComposeAndCheckIfEqual(Value value1, - Value value2) { +FailureOr +mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) { assert(value1.getType().isIndex() && "expected index type"); assert(value2.getType().isIndex() && "expected index type"); @@ -123,9 +123,6 @@ FailureOr mlir::affine::fullyComposeAndCheckIfEqual(Value value1, ValueDimList valueDims; for (Value v : mapOperands) valueDims.push_back({v, std::nullopt}); - FailureOr bound = ValueBoundsConstraintSet::computeConstantBound( + return ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::EQ, map, valueDims); - if (failed(bound)) - return failure(); - return *bound == 0; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 221bec713b38aa..cbb2c507de69f9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { if (auto transferWriteUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( - cast(transferWrite.getOperation()), - cast( - transferWriteUse.getOperation()))) + cast(*transferWrite), + cast(*transferWriteUse), + /*testDynamicValueUsingBounds=*/true)) 
return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( - cast(transferWrite.getOperation()), - cast( - transferReadUse.getOperation()))) + cast(*transferWrite), + cast(*transferReadUse), + /*testDynamicValueUsingBounds=*/true)) return WalkResult::advance(); } else { // Unknown use, we cannot prove that it doesn't alias with the diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index 9ec919423b3428..70f3fa8c297d4b 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRVectorAttributesIncGen LINK_LIBS PUBLIC + MLIRAffineDialect MLIRArithDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces @@ -22,5 +23,6 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRMemRefDialect MLIRSideEffectInterfaces MLIRTensorDialect + MLIRValueBoundsOpInterface MLIRVectorInterfaces ) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 044b6cc07d3d62..68a5cf209f2fb4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -30,6 +31,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -168,39 +170,76 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write, } bool mlir::vector::isDisjointTransferIndices( - VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) { + 
VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, + bool testDynamicValueUsingBounds) { // For simplicity only look at transfer of same type. if (transferA.getVectorType() != transferB.getVectorType()) return false; unsigned rankOffset = transferA.getLeadingShapedRank(); for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { - auto indexA = getConstantIntValue(transferA.indices()[i]); - auto indexB = getConstantIntValue(transferB.indices()[i]); - // If any of the indices are dynamic we cannot prove anything. - if (!indexA.has_value() || !indexB.has_value()) - continue; + Value indexA = transferA.indices()[i]; + Value indexB = transferB.indices()[i]; + std::optional cstIndexA = getConstantIntValue(indexA); + std::optional cstIndexB = getConstantIntValue(indexB); if (i < rankOffset) { // For leading dimensions, if we can prove that index are different we // know we are accessing disjoint slices. - if (*indexA != *indexB) - return true; + if (cstIndexA.has_value() && cstIndexB.has_value()) { + if (*cstIndexA != *cstIndexB) + return true; + continue; + } + if (testDynamicValueUsingBounds) { + // First try to see if we can fully compose and simplify the affine + // expression as a fast track. + FailureOr delta = + affine::fullyComposeAndComputeConstantDelta(indexA, indexB); + if (succeeded(delta) && *delta != 0) + return true; + + FailureOr testEqual = + ValueBoundsConstraintSet::areEqual(indexA, indexB); + if (succeeded(testEqual) && !testEqual.value()) + return true; + } } else { // For this dimension, we slice a part of the memref we need to make sure // the intervals accessed don't overlap. 
- int64_t distance = std::abs(*indexA - *indexB); - if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) - return true; + int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset); + if (cstIndexA.has_value() && cstIndexB.has_value()) { + int64_t distance = std::abs(*cstIndexA - *cstIndexB); + if (distance >= vectorDim) + return true; + continue; + } + if (testDynamicValueUsingBounds) { + // First try to see if we can fully compose and simplify the affine + // expression as a fast track. + FailureOr delta = + affine::fullyComposeAndComputeConstantDelta(indexA, indexB); + if (succeeded(delta) && std::abs(*delta) >= vectorDim) + return true; + + FailureOr computeDelta = + ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB); + if (succeeded(computeDelta)) { + if (std::abs(computeDelta.value()) >= vectorDim) + return true; + } + } } } return false; } bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA, - VectorTransferOpInterface transferB) { + VectorTransferOpInterface transferB, + bool testDynamicValueUsingBounds) { if (transferA.source() != transferB.source()) return false; - return isDisjointTransferIndices(transferA, transferB); + return isDisjointTransferIndices(transferA, transferB, + testDynamicValueUsingBounds); } // Helper to iterate over n-D vector slice elements. Calculate the next diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 603b88f11c8e00..a5f1b28152b9bd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { // Don't need to consider disjoint accesses. 
if (vector::isDisjointTransferSet( cast(write.getOperation()), - cast(transferOp.getOperation()))) + cast(transferOp.getOperation()), + /*testDynamicValueUsingBounds=*/true)) continue; } blockingAccesses.push_back(user); @@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { // the write. if (vector::isDisjointTransferSet( cast(write.getOperation()), - cast(read.getOperation()))) + cast(read.getOperation()), + /*testDynamicValueUsingBounds=*/true)) continue; if (write.getSource() == read.getSource() && dominators.dominates(write, read) && checkSameValueRAW(write, read)) { diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index c00ee0315a9639..ff941115219f68 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -484,25 +484,32 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( return failure(); } -FailureOr -ValueBoundsConstraintSet::areEqual(Value value1, Value value2, - std::optional dim1, - std::optional dim2) { +FailureOr +ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, + std::optional dim1, + std::optional dim2) { #ifndef NDEBUG assertValidValueDim(value1, dim1); assertValidValueDim(value2, dim2); #endif // NDEBUG - // Subtract the two values/dimensions from each other. If the result is 0, - // both are equal. Builder b(value1.getContext()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); - FailureOr bound = computeConstantBound( - presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}}); - if (failed(bound)) + return computeConstantBound(presburger::BoundType::EQ, map, + {{value1, dim1}, {value2, dim2}}); +} + +FailureOr +ValueBoundsConstraintSet::areEqual(Value value1, Value value2, + std::optional dim1, + std::optional dim2) { + // Subtract the two values/dimensions from each other. 
If the result is 0, + // both are equal. + FailureOr delta = computeConstantDelta(value1, value2, dim1, dim2); + if (failed(delta)) return failure(); - return *bound == 0; + return *delta == 0; } ValueBoundsConstraintSet::BoundBuilder & diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 7d0c3648c344b1..11bf4b58b95c82 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -872,3 +872,135 @@ transform.sequence failures(propagate) { transform.structured.hoist_redundant_vector_transfers %0 : (!transform.any_op) -> !transform.any_op } + +// ----- + +// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values. + +// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)> + +// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic +// CHECK-SAME: (%[[BUFFER:.+]]: memref, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index) + +// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]] +// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]] +// CHECK: %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]] +// CHECK: %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]] +// CHECK: %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]] +// CHECK-COUNT-2: scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>) +// CHECK-COUNT-3: "some_use" +// CHECK-COUNT-2: scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32> +// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]] +// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]] +// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]] + +func.func @hoist_vector_transfer_pairs_disjoint_dynamic( + %buffer: memref, %lb : index, %ub : index, %step: index, %i0 : index) { + %cst = arith.constant 0.0 : f32 + %i1 = affine.apply 
affine_map<(d0) -> (d0 + 1)>(%i0) + %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0) + + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref, vector<4xf32> + // Disjoint leading dim + %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref, vector<4xf32> + // Non-overlap trailing dim + %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref, vector<4xf32> + %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32> + %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32> + %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32> + vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref + vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref + vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +// Test that we cannot hoist out read-write pairs whose indices are overlapping. 
+ +// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic +// CHECK-COUNT-2: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-2: vector.transfer_write + +func.func @hoist_vector_transfer_pairs_overlapping_dynamic( + %buffer: memref, %lb : index, %ub : index, %step: index, %i0 : index) { + %cst = arith.constant 0.0 : f32 + %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0) + + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref, vector<4xf32> + // Overlapping range with the above + %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref, vector<4xf32> + %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32> + %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32> + vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref + vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values. 
+ +// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic +// CHECK-COUNT-3: vector.transfer_read +// CHECK-COUNT-2: %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>) +// CHECK-COUNT-2: scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32> +// CHECK-COUNT-3: vector.transfer_write +// CHECK: return + +func.func @hoist_vector_transfer_pairs_disjoint_dynamic( + %buffer: memref, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) { + %cst = arith.constant 0.0 : f32 + %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1) + %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1) + %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1) + + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref, vector<16x8xf32> + %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref, vector<16x8xf32> + %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref, vector<16x8xf32> + %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32> + %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32> + %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32> + vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref + vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref + vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op +} diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir index f43367ab4aeba7..13957af014b89e 100644 --- 
a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir +++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir @@ -256,3 +256,107 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) { } return } + +// CHECK-LABEL: func @forward_dead_store_dynamic_same_index +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: scf.for +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func.func @forward_dead_store_dynamic_same_index( + %buffer : memref, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i : index) { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref + // The following transfer op reads/writes to the same address so that we can forward. + %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref, vector<4xf32> + %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) { + %1 = arith.addf %acc, %acc : vector<4xf32> + scf.yield %1 : vector<4xf32> + } + vector.transfer_write %x, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref + return +} + +// CHECK-LABEL: func @dont_forward_dead_store_dynamic_overlap +// CHECK-COUNT-2: vector.transfer_write +// CHECK: vector.transfer_read +// CHECK: scf.for +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func.func @dont_forward_dead_store_dynamic_overlap( + %buffer : memref, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0) + vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + // The following transfer op writes to an overlapping range so we cannot forward. 
+ vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref + %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref, vector<4xf32> + %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) { + %1 = arith.addf %acc, %acc : vector<4xf32> + scf.yield %1 : vector<4xf32> + } + vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + return +} + +// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_leading_dim +// CHECK: vector.transfer_write +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: scf.for +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func.func @forward_dead_store_dynamic_non_overlap_leading_dim( + %buffer : memref, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0) + vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + // The following transfer op writes to a non-overlapping range so we can forward. 
+ vector.transfer_write %v0, %buffer[%i1, %i0] {in_bounds = [true]} : vector<4xf32>, memref + %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref, vector<4xf32> + %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) { + %1 = arith.addf %acc, %acc : vector<4xf32> + scf.yield %1 : vector<4xf32> + } + vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + return +} + +// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_trailing_dim +// CHECK: vector.transfer_write +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: scf.for +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func.func @forward_dead_store_dynamic_non_overlap_trailing_dim( + %buffer : memref, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + %i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0) + vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + // The following transfer op writes to a non-overlapping range so we can forward. 
+ vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref + %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref, vector<4xf32> + %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) { + %1 = arith.addf %acc, %acc : vector<4xf32> + scf.yield %1 : vector<4xf32> + } + vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 6e3c3dff759a2e..2f1631cbdb02e0 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -187,20 +187,26 @@ static LogicalResult testEquality(func::FuncOp funcOp) { op->emitOpError("invalid op"); return WalkResult::skip(); } - FailureOr equal = failure(); if (op->hasAttr("compose")) { - equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0), - op->getOperand(1)); - } else { - equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0), - op->getOperand(1)); - } - if (failed(equal)) { - op->emitError("could not determine equality"); - } else if (*equal) { - op->emitRemark("equal"); + FailureOr equal = affine::fullyComposeAndComputeConstantDelta( + op->getOperand(0), op->getOperand(1)); + if (failed(equal)) { + op->emitError("could not determine equality"); + } else if (*equal == 0) { + op->emitRemark("equal"); + } else { + op->emitRemark("different"); + } } else { - op->emitRemark("different"); + FailureOr equal = ValueBoundsConstraintSet::areEqual( + op->getOperand(0), op->getOperand(1)); + if (failed(equal)) { + op->emitError("could not determine equality"); + } else if (*equal) { + op->emitRemark("equal"); + } else { + op->emitRemark("different"); + } } } return WalkResult::advance(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 
de13e03807e821..63f9cdafce88b9 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4422,6 +4422,7 @@ cc_library( ]), includes = ["include"], deps = [ + ":AffineDialect", ":ArithDialect", ":ArithUtils", ":ControlFlowInterfaces", @@ -4436,6 +4437,7 @@ cc_library( ":SideEffectInterfaces", ":Support", ":TensorDialect", + ":ValueBoundsOpInterface", ":VectorAttributesIncGen", ":VectorDialectIncGen", ":VectorInterfaces",