
[mlir][IntRangeInference] Infer values for {memref,tensor}.dim #122945

Merged
merged 2 commits into llvm:main from shaped-dim-like-int-range-inference on Jan 30, 2025

Conversation

krzysz00
Contributor

Implement the integer range inference interface for `memref.dim` and `tensor.dim` using shared code. The inference treats dynamic dimensions as [0, index_max] and takes the union of the sizes of every dimension the `dim` operand could validly refer to.
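For illustration (this snippet is not part of the patch, but it mirrors the `@dim_some_omitting_dynamic` tests added below), clamping the dimension operand lets the analysis drop the dynamic dimension from the union:

```mlir
// Sketch: %x is unknown, but arith.maxsi bounds it to >= 1, so only dims 1 and
// 2 of memref<?x3x5xi32> can be queried and the result range is [3, 5].
func.func @example(%m: memref<?x3x5xi32>, %x: index) -> index {
  %c1 = arith.constant 1 : index
  %d = arith.maxsi %x, %c1 : index
  %0 = memref.dim %m, %d : memref<?x3x5xi32>
  return %0 : index
}
```

On the implementation side, an op that implements `ShapedDimOpInterface` hooks into the shared helper roughly as follows. The op name here is hypothetical; the `memref.dim` and `tensor.dim` changes in this PR follow exactly this pattern:

```cpp
// Hypothetical wiring for some MyDimOp (not a real op); see the DimOp changes
// in MemRefOps.cpp and TensorOps.cpp below for the real instances.
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

void MyDimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                            SetIntLatticeFn setResultRange) {
  // argRanges[1] is the (possibly uninitialized) range of the dimension operand;
  // the shared helper unions the sizes of all dimensions it could select.
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
```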

@llvmbot
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

Implement the integer range inference interface for `memref.dim` and `tensor.dim` using shared code. The inference treats dynamic dimensions as [0, index_max] and takes the union of the sizes of every dimension the `dim` operand could validly refer to.


Full diff: https://github.com/llvm/llvm-project/pull/122945.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+1)
  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4-2)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+1)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3-1)
  • (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+8)
  • (modified) mlir/lib/Dialect/MemRef/IR/CMakeLists.txt (+2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+7)
  • (modified) mlir/lib/Dialect/Tensor/IR/CMakeLists.txt (+2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8)
  • (modified) mlir/lib/Interfaces/Utils/CMakeLists.txt (+1)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+44)
  • (added) mlir/test/Dialect/MemRef/int-range-inference.mlir (+61)
  • (added) mlir/test/Dialect/Tensor/int-range-inference.mlir (+61)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 72463dca715ca3..ac383ab46e7a50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..c3ee3968abc16d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     }]>,
 
     // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for 
+    // specified. Otherwise, the op may be ambiguous. The output shape for
     // the op will be inferred using the inferOutputShape() method.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0a21c9922b223b..bd96337a55407a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac209845020..38874513a4cc00 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -195,7 +196,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `tensor.dim` operation takes a tensor and a dimension operand of type
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 3988a8826498a9..e46358ccfc46f7 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,8 @@
 #include <optional>
 
 namespace mlir {
+class ShapedDimOpInterface;
+
 namespace intrange {
 /// Function that performs inference on an array of `ConstantIntRanges`,
 /// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
                                  const ConstantIntRanges &lhs,
                                  const ConstantIntRanges &rhs);
 
+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
+/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                            const IntegerValueRange &maybeDim);
+
 } // namespace intrange
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 845914ebd107a2..734294bd014c6e 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDialect
   MLIRDialectUtils
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc..f0aee7a68e0bff 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it. Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index cfdd3847761a49..5425615dac3932 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParallelCombiningOpInterface
   MLIRShapedOpInterfaces
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d553153198..e0853cab60fb94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -23,7 +23,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index ece6c8e46ffea9..8c45f669974271 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
     MLIRInferIntRangeInterfaceIncGen
 
     LINK_LIBS PUBLIC
+    MLIRShapedOpInterfaces
     MLIRInferIntRangeInterface
     MLIRIR
 )
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 1eab4139488bdd..2f47939df5a022 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
     return false;
   return std::nullopt;
 }
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                          const IntegerValueRange &maybeDim) {
+  unsigned width =
+      ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+  APInt zero = APInt::getZero(width);
+  APInt typeMax = APInt::getSignedMaxValue(width);
+
+  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+  if (!shapedTy.hasRank())
+    return ConstantIntRanges::fromSigned(zero, typeMax);
+
+  int64_t rank = shapedTy.getRank();
+  int64_t minDim = 0;
+  int64_t maxDim = rank - 1;
+  if (!maybeDim.isUninitialized()) {
+    const ConstantIntRanges &dim = maybeDim.getValue();
+    minDim = std::max(minDim, dim.smin().getSExtValue());
+    maxDim = std::min(maxDim, dim.smax().getSExtValue());
+  }
+
+  std::optional<ConstantIntRanges> result;
+  auto joinResult = [&](const ConstantIntRanges &thisResult) {
+    if (!result.has_value())
+      result = thisResult;
+    else
+      result = result->rangeUnion(thisResult);
+  };
+  for (int64_t i = minDim; i <= maxDim; ++i) {
+    int64_t length = shapedTy.getDimSize(i);
+
+    if (ShapedType::isDynamic(length))
+      joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
+    else
+      joinResult(ConstantIntRanges::constant(APInt(width, length)));
+  }
+  return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
new file mode 100644
index 00000000000000..e2aa487eaaa25b
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = memref.dim %m, %0 : memref<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
new file mode 100644
index 00000000000000..384ae781e0e330
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: tensor<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %m, %c0 : tensor<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %m, %x : tensor<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: tensor<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %m, %c0 : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: tensor<?x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %m, %x : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: tensor<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = tensor.dim %m, %0 : tensor<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}

@llvmbot
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Krzysztof Drewniak (krzysz00)

Changes

Implement the integer range inference interface for `memref.dim` and `tensor.dim` using shared code. The inference treats dynamic dimensions as [0, index_max] and takes the union of the sizes of every dimension the `dim` operand could validly refer to.


Full diff: https://github.com/llvm/llvm-project/pull/122945.diff

(File list and diff identical to the one shown above.)

Implement the integer range inference interface for memref.dim and
tensor.dim using shared code. The inference will infer the `dim` of
dynamic dimensions to [0, index_max] and take the union of all the
dimensions that the `dim` argument could be validly referring to.
krzysz00 force-pushed the shaped-dim-like-int-range-inference branch from 0436922 to fdb82c1 on January 18, 2025, 04:09
krzysz00 merged commit cdc09a1 into llvm:main on Jan 30, 2025
8 checks passed