Skip to content

Commit

Permalink
[mlir][IntRangeInference] Infer values for {memref,tensor}.dim (#122945)
Browse files Browse the repository at this point in the history
Implement the integer range inference interface for memref.dim and
tensor.dim using shared code. The inference will infer the `dim` of
dynamic dimensions to [0, index_max] and take the union of all the
dimensions that the `dim` argument could be validly referring to.
  • Loading branch information
krzysz00 authored Jan 30, 2025
1 parent de7438e commit cdc09a1
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 3 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
Expand Down Expand Up @@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemRefsNormalizable,
ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
ShapedDimOpInterface,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `dim` operation takes a memref and a dimension operand of type `index`.
Expand Down Expand Up @@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}]>,

// Builder that infers the result layout map. The result shape must be
// specified. Otherwise, the op may be ambiguous. The output shape for
// specified. Otherwise, the op may be ambiguous. The output shape for
// the op will be inferred using the inferOutputShape() method.
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation)>,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
Expand Down Expand Up @@ -197,7 +198,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
def Tensor_DimOp : Tensor_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
ShapedDimOpInterface,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
let summary = "dimension index operation";
let description = [{
The `tensor.dim` operation takes a tensor and a dimension operand of type
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <optional>

namespace mlir {
class ShapedDimOpInterface;

namespace intrange {
/// Function that performs inference on an array of `ConstantIntRanges`,
/// abstracted away here to permit writing the function that handles both
Expand Down Expand Up @@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs);

/// Returns the integer range for the result of a `ShapedDimOpInterface` given
/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
const IntegerValueRange &maybeDim);

} // namespace intrange
} // namespace mlir

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRControlFlowInterfaces
MLIRDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
Expand Down Expand Up @@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}

/// InferIntRangeInterface: bound the index-typed result of `memref.dim` by
/// delegating to the shared ShapedDimOpInterface helper, which unions the
/// sizes of every dimension the `dimension` operand may select.
/// argRanges[1] is the inferred range of the index operand (operand 0, the
/// memref, carries no integer range).
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRParallelCombiningOpInterface
MLIRShapedOpInterfaces
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}

/// InferIntRangeInterface: bound the index-typed result of `tensor.dim` by
/// delegating to the shared ShapedDimOpInterface helper, which unions the
/// sizes of every dimension the `dimension` operand may select.
/// argRanges[1] is the inferred range of the index operand (operand 0, the
/// tensor, carries no integer range).
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Interfaces/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterfaceIncGen

LINK_LIBS PUBLIC
MLIRShapedOpInterfaces
MLIRInferIntRangeInterface
MLIRIR
)
44 changes: 44 additions & 0 deletions mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
return false;
return std::nullopt;
}

//===----------------------------------------------------------------------===//
// Shaped type dimension accessors / ShapedDimOpInterface
//===----------------------------------------------------------------------===//

/// Infer the integer range of the result of a `dim`-like op implementing
/// ShapedDimOpInterface. `maybeDim` is the (possibly uninitialized) inferred
/// range of the dimension-index operand; it is used to narrow the set of
/// dimensions the op may be querying. The result is the union of the ranges
/// of every candidate dimension, where a dynamic dimension contributes the
/// conservative [0, signed_max] and a static one contributes its exact size.
ConstantIntRanges
mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
                                          const IntegerValueRange &maybeDim) {
  // Bit width of the (index-typed) result.
  unsigned width =
      ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
  const APInt lowerZero = APInt::getZero(width);
  const APInt signedMax = APInt::getSignedMaxValue(width);
  // Conservative answer, used for unranked shapes, dynamic dimensions, and
  // when no candidate dimension is in bounds.
  const ConstantIntRanges fallback =
      ConstantIntRanges::fromSigned(lowerZero, signedMax);

  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
  if (!shapedTy.hasRank())
    return fallback;

  // Clamp the candidate dimension indices to [0, rank - 1], tightening with
  // the known range of the index operand when one is available.
  int64_t rank = shapedTy.getRank();
  int64_t firstDim = 0;
  int64_t lastDim = rank - 1;
  if (!maybeDim.isUninitialized()) {
    const ConstantIntRanges &dimRange = maybeDim.getValue();
    firstDim = std::max(firstDim, dimRange.smin().getSExtValue());
    lastDim = std::min(lastDim, dimRange.smax().getSExtValue());
  }

  // Union the per-dimension ranges over all candidates. If the clamped
  // interval is empty (index provably out of bounds), fall back to the
  // conservative range.
  std::optional<ConstantIntRanges> accumulated;
  for (int64_t dim = firstDim; dim <= lastDim; ++dim) {
    int64_t size = shapedTy.getDimSize(dim);
    ConstantIntRanges dimResult =
        ShapedType::isDynamic(size)
            ? fallback
            : ConstantIntRanges::constant(APInt(width, size));
    accumulated =
        accumulated ? accumulated->rangeUnion(dimResult) : dimResult;
  }
  return accumulated.value_or(fallback);
}
74 changes: 74 additions & 0 deletions mlir/test/Dialect/MemRef/int-range-inference.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s

// Tests for integer range inference on memref.dim (via ShapedDimOpInterface).

// Static shape, constant in-bounds index: the dim folds to a constant.
// CHECK-LABEL: @dim_const
// CHECK: %[[ret:.+]] = arith.constant 3 : index
// CHECK: return %[[ret]]
func.func @dim_const(%m: memref<3x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = memref.dim %m, %c0 : memref<3x5xi32>
  return %0 : index
}

// -----

// Static shape, unknown index: the result range is the union of all
// dimension sizes, here [3, 5].
// CHECK-LABEL: @dim_any_static
// CHECK: %[[op:.+]] = memref.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
  %0 = memref.dim %m, %x : memref<3x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// Constant index selecting a dynamic dimension: conservative
// [0, signed_max(index)] bound.
// CHECK-LABEL: @dim_dynamic
// CHECK: %[[op:.+]] = memref.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = memref.dim %m, %c0 : memref<?x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// Unknown index where one candidate dimension is dynamic: the union
// includes the conservative dynamic range.
// CHECK-LABEL: @dim_any_dynamic
// CHECK: %[[op:.+]] = memref.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
  %0 = memref.dim %m, %x : memref<?x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// The index is known to be >= 1, so the dynamic dimension 0 is excluded
// from the union and the result is bounded by the static sizes [3, 5].
// CHECK-LABEL: @dim_some_omitting_dynamic
// CHECK: %[[op:.+]] = memref.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
  %c1 = arith.constant 1 : index
  %0 = arith.maxsi %x, %c1 : index
  %1 = memref.dim %m, %0 : memref<?x3x5xi32>
  %2 = test.reflect_bounds %1 : index
  return %2 : index
}

// -----

// Unranked memref: nothing is known about any dimension, so the
// conservative [0, signed_max(index)] bound is produced.
// CHECK-LABEL: @dim_unranked
// CHECK: %[[op:.+]] = memref.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_unranked(%m: memref<*xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = memref.dim %m, %c0 : memref<*xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}
74 changes: 74 additions & 0 deletions mlir/test/Dialect/Tensor/int-range-inference.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s

// Tests for integer range inference on tensor.dim (via ShapedDimOpInterface).

// Static shape, constant in-bounds index: the dim folds to a constant.
// CHECK-LABEL: @dim_const
// CHECK: %[[ret:.+]] = arith.constant 3 : index
// CHECK: return %[[ret]]
func.func @dim_const(%t: tensor<3x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %t, %c0 : tensor<3x5xi32>
  return %0 : index
}

// -----

// Static shape, unknown index: the result range is the union of all
// dimension sizes, here [3, 5].
// CHECK-LABEL: @dim_any_static
// CHECK: %[[op:.+]] = tensor.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index {
  %0 = tensor.dim %t, %x : tensor<3x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// Constant index selecting a dynamic dimension: conservative
// [0, signed_max(index)] bound.
// CHECK-LABEL: @dim_dynamic
// CHECK: %[[op:.+]] = tensor.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_dynamic(%t: tensor<?x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %t, %c0 : tensor<?x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// Unknown index where one candidate dimension is dynamic: the union
// includes the conservative dynamic range.
// CHECK-LABEL: @dim_any_dynamic
// CHECK: %[[op:.+]] = tensor.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_dynamic(%t: tensor<?x5xi32>, %x: index) -> index {
  %0 = tensor.dim %t, %x : tensor<?x5xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

// -----

// The index is known to be >= 1, so the dynamic dimension 0 is excluded
// from the union and the result is bounded by the static sizes [3, 5].
// CHECK-LABEL: @dim_some_omitting_dynamic
// CHECK: %[[op:.+]] = tensor.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_some_omitting_dynamic(%t: tensor<?x3x5xi32>, %x: index) -> index {
  %c1 = arith.constant 1 : index
  %0 = arith.maxsi %x, %c1 : index
  %1 = tensor.dim %t, %0 : tensor<?x3x5xi32>
  %2 = test.reflect_bounds %1 : index
  return %2 : index
}

// -----

// Unranked tensor: nothing is known about any dimension, so the
// conservative [0, signed_max(index)] bound is produced.
// CHECK-LABEL: @dim_unranked
// CHECK: %[[op:.+]] = tensor.dim
// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
// CHECK: return %[[ret]]
func.func @dim_unranked(%t: tensor<*xi32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %t, %c0 : tensor<*xi32>
  %1 = test.reflect_bounds %0 : index
  return %1 : index
}

0 comments on commit cdc09a1

Please sign in to comment.