
[mlir] Handle arith.const expr in dispatchIndexOpFoldResult func #122432

Closed · wants to merge 2 commits

Conversation

@rutkoor (Contributor) commented Jan 10, 2025

This PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper function dispatched an OpFoldResult into staticVec only if it was an IntegerAttr. The changes in this PR now enable the evaluation of arith.constant expressions, extraction of the integer value, and dispatch into staticVec.
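To illustrate the behavioral change (a minimal sketch with made-up variable names, not code from the patch): suppose c2 is the Value produced by `%c2 = arith.constant 2 : index` and is handed to the helper as an OpFoldResult.

SmallVector<Value> dynamicVec;
SmallVector<int64_t> staticVec;
// c2 is a Value rather than an IntegerAttr, so it used to be treated as
// dynamic even though its value is a compile-time constant.
dispatchIndexOpFoldResult(OpFoldResult(c2), dynamicVec, staticVec);
// Before this PR: staticVec == {ShapedType::kDynamic}, dynamicVec == {c2}
// After this PR:  staticVec == {2},                    dynamicVec == {}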

@llvmbot (Member) commented Jan 10, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (rutkoor)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/122432.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+13)
  • (modified) mlir/test/Dialect/Tensor/bubble-reshapes.mlir (+20)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..7ad4c982af2aae 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
@@ -54,6 +55,18 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
+
+  Operation *definingOp = v.getDefiningOp();
+  if (definingOp) {
+    // Check if definingOp is an arith.constant
+    if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
+      if (auto intAttr = mlir::dyn_cast<IntegerAttr>(constantOp.getValue())) {
+        staticVec.push_back(intAttr.getValue().getSExtValue());
+        return;
+      }
+    }
+  }
+
   dynamicVec.push_back(v);
   staticVec.push_back(ShapedType::kDynamic);
 }
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
 
 // -----
 
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
+  return %expand : tensor<?x4x2x3xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
 func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
   %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
   %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]

@rutkoor (Contributor, Author) commented Jan 10, 2025

cc: @javedabsar , @MaheshRavishankar

@matthias-springer (Member) left a comment:

SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); implements the functionality that you are looking for. You can use it in combination with the existing dispatchIndexOpFoldResults.
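A minimal sketch of the suggested combination, assuming values is a ValueRange of index-typed SSA values at some call site (the variable names are mine):

// getAsOpFoldResult folds Values defined by constant ops into IntegerAttrs.
SmallVector<OpFoldResult> mixed = getAsOpFoldResult(values);
SmallVector<Value> dynamicVec;
SmallVector<int64_t> staticVec;
// The dispatch helper then only switches on Attribute vs. Value; it never
// has to inspect the defining ops itself.
dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);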

Inline comment on the new include in mlir/lib/Dialect/Utils/StaticValueUtils.cpp:

@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"

@matthias-springer (Member): This file should not depend on any dialect.

@rutkoor (Contributor, Author): Fixed. Thanks a lot for the suggestion.

@matthias-springer (Member): Oh, I mean you can use dispatchIndexOpFoldResult(getAsOpFoldResult(v)) wherever you need it. I wouldn't call getAsOpFoldResult from dispatchIndexOpFoldResult, because that does not fit the name of the function. This function is just a switch that populates two vectors; it's not meant to analyze any IR.

Why do you need this functionality?

Inline comment on the new test in mlir/test/Dialect/Tensor/bubble-reshapes.mlir:

  %c3 = arith.constant 3 : index
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

@matthias-springer (Member): I'm not that familiar with this op anymore, but I expected output_shape [%s0, 4, 2, 3].

@matthias-springer (Member) commented Jan 10, 2025: Actually, I'm surprised that the verifier allows this op. @MaheshRavishankar to clarify.

Contributor: yeah, probably missing a canonicalizer here.

@matthias-springer (Member): I expected the output_shape to match the result type. I.e., an output dim must be static iff the respective dim is static in the result type. Is that not the case?

Contributor: The verifier doesn't enforce it (and it would be wrong to do so; it's not necessarily wrong IR), but we could convert the dynamic values in output_shape to static values.

@matthias-springer (Member): For consistency with other tensor dialect ops, I would recommend making the verifier stricter. See the discussion here.

Inline comment on the updated patch:

@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
 
+  OpFoldResult result = getAsOpFoldResult(v);
+  if (auto attr = result.dyn_cast<Attribute>()) {
Contributor: I think you could dyn_cast the OpFoldResult directly as IntegerAttr; that saves the forced cast on the next line. Something like if (auto iattr = dyn_cast<IntegerAttr>(result)).
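A sketch of the difference, reusing result and staticVec from the patch above. Whether a single-step dyn_cast<IntegerAttr>(result) compiles depends on the cast traits the MLIR revision at hand provides for OpFoldResult, so the compact form below collapses the two casts into one expression instead:

// Two-step form from the patch: cast to Attribute first, then cast again.
if (auto attr = dyn_cast<Attribute>(result)) {
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    staticVec.push_back(intAttr.getInt());
    return;
  }
}

// Compact form in the spirit of the suggestion: a single condition with no
// forced cast. dyn_cast<Attribute> yields a null Attribute when result holds
// a Value, and dyn_cast_if_present then handles the null.
if (auto intAttr = dyn_cast_if_present<IntegerAttr>(dyn_cast<Attribute>(result))) {
  staticVec.push_back(intAttr.getInt());
  return;
}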

@MaheshRavishankar (Contributor) left a comment:

This is maybe OK... but it seems like it is trying to account for some other issue. Is the issue the expand_shape op? Then you are probably missing a canonicalizer that can fold the constant shapes into the output shape.
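For concreteness, here is what such a canonicalizer could look like. This is a hypothetical sketch, not an upstream pattern: the pattern name is mine, and it assumes ExpandShapeOp::getMixedOutputShape() is available, as in recent MLIR.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

/// Fold constant-defined output_shape operands of tensor.expand_shape into
/// the static size list, e.g. output_shape [%s0, %s1, %c2, %c3] becomes
/// output_shape [%s0, %s1, 2, 3].
struct FoldConstantsIntoOutputShape
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixed = op.getMixedOutputShape();
    bool changed = false;
    for (OpFoldResult &ofr : mixed) {
      Value value = dyn_cast<Value>(ofr);
      if (!value)
        continue;
      // getAsOpFoldResult returns an IntegerAttr if `value` is defined by a
      // constant op; otherwise it returns the Value unchanged.
      OpFoldResult folded = getAsOpFoldResult(value);
      if (isa<Attribute>(folded)) {
        ofr = folded;
        changed = true;
      }
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, op.getResultType(), op.getSrc(), op.getReassociationIndices(),
        mixed);
    return success();
  }
};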


@rutkoor (Contributor, Author) commented Jan 14, 2025

Hi @matthias-springer, @MaheshRavishankar, thanks for the comments. To give you some background: I encountered an issue with a test case from a third-party library when lowering to an MLIR module. Specifically, the stride information is not being populated when creating the memref.reinterpret_cast instruction. After investigation, I traced the problem to the dispatchIndexOpFoldResult function, which is invoked from the third-party library.

Given my six months of experience with MLIR, I am uncertain whether these changes should be made in the MLIR repository or within the third-party library. The builder is currently failing, resulting in 16 failing test cases; I have resolved 15 of them by adjusting the static shape and stride information.

I believe this change will not only benefit the third-party library but also improve some MLIR tests by enabling vector code generation through static stride information. Additionally, I am aware of issues with the verifier, which is incorrectly accepting certain operations.

I would appreciate feedback on whether it is advisable to proceed with this PR. Your guidance will be invaluable.

Thank you.

cc: @javedabsar1

@matthias-springer (Member) commented:

@rutkoor Can you post the C++ code that builds the tensor.expand_shape op? I suspect that something is wrong there.

As the name suggests, dispatchIndexOpFoldResult should just dispatch based on the type of the OpFoldResult. I.e., it's just a switch-case statement and not meant to look at the IR that it's operating on.
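In other words, the helper is meant to have the following shape. This is a paraphrased sketch of that intent, not a verbatim copy; the real helper extracts the integer through an APSInt, as the patch context above shows.

void dispatchIndexOpFoldResult(OpFoldResult ofr,
                               SmallVectorImpl<Value> &dynamicVec,
                               SmallVectorImpl<int64_t> &staticVec) {
  // Pure type switch: an Attribute goes to the static side...
  if (auto attr = dyn_cast<Attribute>(ofr)) {
    staticVec.push_back(cast<IntegerAttr>(attr).getInt());
    return;
  }
  // ...and a Value goes to the dynamic side. No IR is inspected.
  dynamicVec.push_back(cast<Value>(ofr));
  staticVec.push_back(ShapedType::kDynamic);
}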

@rutkoor (Contributor, Author) commented Jan 14, 2025

> @rutkoor Can you post the C++ code that builds the tensor.expand_shape op? I suspect that something is wrong there.
>
> As the name suggests, dispatchIndexOpFoldResult should just dispatch based on the type of the OpFoldResult. I.e., it's just a switch-case statement and not meant to look at the IR that it's operating on.

This is the code:

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

It invokes the decomposeMixedValues function, which tries to extract the integer from the expressions.

/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
  SmallVector<int64_t> staticValues;
  SmallVector<Value> dynamicValues;
  for (const auto &it : mixedValues) {
    if (auto attr = dyn_cast<Attribute>(it)) {
      staticValues.push_back(cast<IntegerAttr>(attr).getInt());
    } else {
      staticValues.push_back(ShapedType::kDynamic);
      dynamicValues.push_back(cast<Value>(it));
    }
  }
  return {staticValues, dynamicValues};
}

Instead of invoking the decomposeMixedValues function, it should invoke dispatchIndexOpFoldResult.
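That is, a sketch of the substitution being proposed, using the same names as the build function quoted above (this is not committed code):

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  SmallVector<int64_t> staticOutputShape;
  SmallVector<Value> dynamicOutputShape;
  // With this PR, the dispatch helper would also see through arith.constant
  // and place such sizes in staticOutputShape.
  dispatchIndexOpFoldResults(outputShape, dynamicOutputShape,
                             staticOutputShape);
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}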

@matthias-springer (Member) commented:

Can you also post the call site of ExpandShapeOp::build? How are resultType and outputShape computed?

@matthias-springer (Member) commented:

What I am wondering: Where is this example op coming from?

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

I'd like to understand why the output_shape is not static.

@rutkoor (Contributor, Author) commented Jan 14, 2025

> What I am wondering: Where is this example op coming from?
>
>   %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
>               output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
>
> I'd like to understand why the output_shape is not static.

Without the changes from this PR, this test case is invalid; it throws the error below:

within split at mlir/test/Dialect/Tensor/bubble-reshapes.mlir:21 offset :7:13: error: 'tensor.expand_shape' op expected dimension 3 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
            ^
within split at mlir/test/Dialect/Tensor/bubble-reshapes.mlir:21 offset :7:13: note: see current operation: %2 = "tensor.expand_shape"(%arg0, %arg1, %0, %1) <{reassociation = [[0], [1], [2], [3, 4]], static_output_shape = array<i64: -9223372036854775808, 2, 2, -9223372036854775808, -9223372036854775808>}> : (tensor<?x2x2x6xf32>, index, index, index) -> tensor<?x2x2x?x?xf32>

The BubbleUpExpandThroughParallelCollapse pattern creates a tensor::ExpandShapeOp; this is where we pass resultType and the other arguments. Below is the code from BubbleUpExpandThroughParallelCollapse.

    // Swap reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);

The output_shape is part of newExpandSizes, which is passed to the dispatchIndexOpFoldResults function.

@matthias-springer (Member) commented:

OK, I understand now what you are trying to do. In my opinion, the input IR is invalid. We should make the verifier stricter to reject such ops.

Invalid:

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

Valid:

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, 4, 2, 3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

The valid IR works with -test-tensor-transform-patterns=test-expand-shape-bubbling.

As for the reason why the first one should be invalid and the second one is valid, take a look at this discussion: https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct/82612. We were discussing the same issue in the context of tensor.pack and came to that conclusion.

@rutkoor (Contributor, Author) commented Jan 15, 2025

Thanks a lot @matthias-springer for providing the details and clarification. Closing the PR now.

cc: @javedabsar1

@rutkoor rutkoor closed this Jan 15, 2025
@rutkoor rutkoor deleted the cg1 branch January 15, 2025 09:59
@MaheshRavishankar (Contributor) commented:

We can't have this be a verifier issue. Let me respond on the Discourse thread.

@MaheshRavishankar (Contributor) commented:

> OK, I understand now what you are trying to do. In my opinion, the input IR is invalid. We should make the verifier stricter to reject such ops.
>
> Invalid:
>
>   %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
>               output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
>
> Valid:
>
>   %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
>               output_shape [%s0, 4, 2, 3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
>
> The valid IR works with -test-tensor-transform-patterns=test-expand-shape-bubbling.
>
> As for the reason why the first one should be invalid and the second one is valid, take a look at this discussion: https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct/82612. We were discussing the same issue in the context of tensor.pack and came to that conclusion.

OK, I misread the IR... I agree with what Matthias says here. The following are valid:

Valid:

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, 4, 2, 3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

or

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x?x?xf32>

but this should be invalid

  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>

That is inconsistent...
