-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IntRangeInference] Infer values for {memref,tensor}.dim #122945
[mlir][IntRangeInference] Infer values for {memref,tensor}.dim #122945
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesImplement the integer range inference niterface for memref.dim and tetnor.dim using shared code. The inference will infer the Full diff: https://github.com/llvm/llvm-project/pull/122945.diff 13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 72463dca715ca3..ac383ab46e7a50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..c3ee3968abc16d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemRefsNormalizable,
ConditionallySpeculatable, NoMemoryEffect,
- ShapedDimOpInterface]> {
+ ShapedDimOpInterface,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}]>,
// Builder that infers the result layout map. The result shape must be
- // specified. Otherwise, the op may be ambiguous. The output shape for
+ // specified. Otherwise, the op may be ambiguous. The output shape for
// the op will be inferred using the inferOutputShape() method.
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation)>,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0a21c9922b223b..bd96337a55407a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac209845020..38874513a4cc00 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -195,7 +196,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
def Tensor_DimOp : Tensor_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
ConditionallySpeculatable, NoMemoryEffect,
- ShapedDimOpInterface]> {
+ ShapedDimOpInterface,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `tensor.dim` operation takes a tensor and a dimension operand of type
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 3988a8826498a9..e46358ccfc46f7 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,8 @@
#include <optional>
namespace mlir {
+class ShapedDimOpInterface;
+
namespace intrange {
/// Function that performs inference on an array of `ConstantIntRanges`,
/// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs);
+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
+/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+ const IntegerValueRange &maybeDim);
+
} // namespace intrange
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 845914ebd107a2..734294bd014c6e 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRControlFlowInterfaces
MLIRDialect
MLIRDialectUtils
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc..f0aee7a68e0bff 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRange) {
+ setResultRange(getResult(),
+ intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index cfdd3847761a49..5425615dac3932 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRParallelCombiningOpInterface
MLIRShapedOpInterfaces
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d553153198..e0853cab60fb94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -23,7 +23,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRange) {
+ setResultRange(getResult(),
+ intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index ece6c8e46ffea9..8c45f669974271 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterfaceIncGen
LINK_LIBS PUBLIC
+ MLIRShapedOpInterfaces
MLIRInferIntRangeInterface
MLIRIR
)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 1eab4139488bdd..2f47939df5a022 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -14,6 +14,7 @@
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
return false;
return std::nullopt;
}
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+ const IntegerValueRange &maybeDim) {
+ unsigned width =
+ ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+ APInt zero = APInt::getZero(width);
+ APInt typeMax = APInt::getSignedMaxValue(width);
+
+ auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+ if (!shapedTy.hasRank())
+ return ConstantIntRanges::fromSigned(zero, typeMax);
+
+ int64_t rank = shapedTy.getRank();
+ int64_t minDim = 0;
+ int64_t maxDim = rank - 1;
+ if (!maybeDim.isUninitialized()) {
+ const ConstantIntRanges &dim = maybeDim.getValue();
+ minDim = std::max(minDim, dim.smin().getSExtValue());
+ maxDim = std::min(maxDim, dim.smax().getSExtValue());
+ }
+
+ std::optional<ConstantIntRanges> result;
+ auto joinResult = [&](const ConstantIntRanges &thisResult) {
+ if (!result.has_value())
+ result = thisResult;
+ else
+ result = result->rangeUnion(thisResult);
+ };
+ for (int64_t i = minDim; i <= maxDim; ++i) {
+ int64_t length = shapedTy.getDimSize(i);
+
+ if (ShapedType::isDynamic(length))
+ joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
+ else
+ joinResult(ConstantIntRanges::constant(APInt(width, length)));
+ }
+ return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
new file mode 100644
index 00000000000000..e2aa487eaaa25b
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %c0 : memref<3x5xi32>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+ %0 = memref.dim %m, %x : memref<3x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %c0 : memref<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
+ %0 = memref.dim %m, %x : memref<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0 = arith.maxsi %x, %c1 : index
+ %1 = memref.dim %m, %0 : memref<?x3x5xi32>
+ %2 = test.reflect_bounds %1 : index
+ return %2 : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
new file mode 100644
index 00000000000000..384ae781e0e330
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: tensor<3x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %m, %c0 : tensor<3x5xi32>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
+ %0 = tensor.dim %m, %x : tensor<3x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %m, %c0 : tensor<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
+ %0 = tensor.dim %m, %x : tensor<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: tensor<?x3x5xi32>, %x: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0 = arith.maxsi %x, %c1 : index
+ %1 = tensor.dim %m, %0 : tensor<?x3x5xi32>
+ %2 = test.reflect_bounds %1 : index
+ return %2 : index
+}
|
@llvm/pr-subscribers-mlir-tensor Author: Krzysztof Drewniak (krzysz00) ChangesImplement the integer range inference niterface for memref.dim and tetnor.dim using shared code. The inference will infer the Full diff: https://github.com/llvm/llvm-project/pull/122945.diff 13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 72463dca715ca3..ac383ab46e7a50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..c3ee3968abc16d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemRefsNormalizable,
ConditionallySpeculatable, NoMemoryEffect,
- ShapedDimOpInterface]> {
+ ShapedDimOpInterface,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}]>,
// Builder that infers the result layout map. The result shape must be
- // specified. Otherwise, the op may be ambiguous. The output shape for
+ // specified. Otherwise, the op may be ambiguous. The output shape for
// the op will be inferred using the inferOutputShape() method.
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation)>,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0a21c9922b223b..bd96337a55407a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac209845020..38874513a4cc00 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -195,7 +196,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
def Tensor_DimOp : Tensor_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
ConditionallySpeculatable, NoMemoryEffect,
- ShapedDimOpInterface]> {
+ ShapedDimOpInterface,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `tensor.dim` operation takes a tensor and a dimension operand of type
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 3988a8826498a9..e46358ccfc46f7 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,8 @@
#include <optional>
namespace mlir {
+class ShapedDimOpInterface;
+
namespace intrange {
/// Function that performs inference on an array of `ConstantIntRanges`,
/// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs);
+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
+/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+ const IntegerValueRange &maybeDim);
+
} // namespace intrange
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 845914ebd107a2..734294bd014c6e 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRControlFlowInterfaces
MLIRDialect
MLIRDialectUtils
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc..f0aee7a68e0bff 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRange) {
+ setResultRange(getResult(),
+ intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index cfdd3847761a49..5425615dac3932 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRParallelCombiningOpInterface
MLIRShapedOpInterfaces
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d553153198..e0853cab60fb94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -23,7 +23,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRange) {
+ setResultRange(getResult(),
+ intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index ece6c8e46ffea9..8c45f669974271 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterfaceIncGen
LINK_LIBS PUBLIC
+ MLIRShapedOpInterfaces
MLIRInferIntRangeInterface
MLIRIR
)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 1eab4139488bdd..2f47939df5a022 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -14,6 +14,7 @@
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
return false;
return std::nullopt;
}
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+ const IntegerValueRange &maybeDim) {
+ unsigned width =
+ ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+ APInt zero = APInt::getZero(width);
+ APInt typeMax = APInt::getSignedMaxValue(width);
+
+ auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+ if (!shapedTy.hasRank())
+ return ConstantIntRanges::fromSigned(zero, typeMax);
+
+ int64_t rank = shapedTy.getRank();
+ int64_t minDim = 0;
+ int64_t maxDim = rank - 1;
+ if (!maybeDim.isUninitialized()) {
+ const ConstantIntRanges &dim = maybeDim.getValue();
+ minDim = std::max(minDim, dim.smin().getSExtValue());
+ maxDim = std::min(maxDim, dim.smax().getSExtValue());
+ }
+
+ std::optional<ConstantIntRanges> result;
+ auto joinResult = [&](const ConstantIntRanges &thisResult) {
+ if (!result.has_value())
+ result = thisResult;
+ else
+ result = result->rangeUnion(thisResult);
+ };
+ for (int64_t i = minDim; i <= maxDim; ++i) {
+ int64_t length = shapedTy.getDimSize(i);
+
+ if (ShapedType::isDynamic(length))
+ joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
+ else
+ joinResult(ConstantIntRanges::constant(APInt(width, length)));
+ }
+ return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
new file mode 100644
index 00000000000000..e2aa487eaaa25b
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %c0 : memref<3x5xi32>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+ %0 = memref.dim %m, %x : memref<3x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %c0 : memref<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
+ %0 = memref.dim %m, %x : memref<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0 = arith.maxsi %x, %c1 : index
+ %1 = memref.dim %m, %0 : memref<?x3x5xi32>
+ %2 = test.reflect_bounds %1 : index
+ return %2 : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
new file mode 100644
index 00000000000000..384ae781e0e330
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: tensor<3x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %m, %c0 : tensor<3x5xi32>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
+ %0 = tensor.dim %m, %x : tensor<3x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %m, %c0 : tensor<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
+ %0 = tensor.dim %m, %x : tensor<?x5xi32>
+ %1 = test.reflect_bounds %0 : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: tensor<?x3x5xi32>, %x: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0 = arith.maxsi %x, %c1 : index
+ %1 = tensor.dim %m, %0 : tensor<?x3x5xi32>
+ %2 = test.reflect_bounds %1 : index
+ return %2 : index
+}
|
Implement the integer range inference niterface for memref.dim and tetnor.dim using shared code. The inference will infer the `dim` of dynamic dimensions to [0, index_max] and take the union of all the dimensions that the `dim` argument could be validly referring to.
0436922
to
fdb82c1
Compare
Implement the integer range inference niterface for memref.dim and tetnor.dim using shared code. The inference will infer the
dim
of dynamic dimensions to [0, index_max] and take the union of all the dimensions that thedim
argument could be validly referring to.