[mlir] Handle arith.const expr in dispatchIndexOpFoldResult func #122432
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: None (rutkoor)

Changes

This PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper dispatched an OpFoldResult into staticVec only if it was an IntegerAttr. With this change, arith.constant expressions are evaluated, the integer value is extracted, and it is dispatched into staticVec.

Full diff: https://github.com/llvm/llvm-project/pull/122432.diff

2 Files Affected:
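For context, the helper's dispatch contract can be modeled outside MLIR with a small stand-in type. This is a minimal sketch, not MLIR API: `FoldResult`, `dispatchFoldResult`, and the string-named SSA values are hypothetical stand-ins for `OpFoldResult`, `dispatchIndexOpFoldResult`, and `Value`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins (not MLIR API): an OpFoldResult is either a
// compile-time integer attribute or a runtime SSA value, named here by string.
using FoldResult = std::variant<int64_t, std::string>;
const int64_t kDynamic = std::numeric_limits<int64_t>::min(); // like ShapedType::kDynamic

// Mirrors the shape of dispatchIndexOpFoldResult: a static integer goes into
// staticVec; anything else is recorded in dynamicVec plus a sentinel in staticVec.
void dispatchFoldResult(const FoldResult &ofr,
                        std::vector<std::string> &dynamicVec,
                        std::vector<int64_t> &staticVec) {
  if (const int64_t *cst = std::get_if<int64_t>(&ofr)) {
    staticVec.push_back(*cst);
    return;
  }
  dynamicVec.push_back(std::get<std::string>(ofr));
  staticVec.push_back(kDynamic);
}
```

In these terms, the PR's question is whether the helper itself should peek through a `%c2 = arith.constant 2 : index` definition to turn a runtime value into a static entry.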
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..7ad4c982af2aae 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
@@ -54,6 +55,18 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}
+
+ Operation *definingOp = v.getDefiningOp();
+ if (definingOp) {
+ // Check if definingOp is an arith.constant
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(constantOp.getValue())) {
+ staticVec.push_back(intAttr.getValue().getSExtValue());
+ return;
+ }
+ }
+ }
+
dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
}
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
// -----
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
+ return %expand : tensor<?x4x2x3xf32>
+}
+// CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); implements the functionality that you are looking for. You can use it in combination with the existing dispatchIndexOpFoldResults.
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
This file should not depend on any dialect.
Fixed. Thanks a lot for the suggestion.
Oh, I mean you can use dispatchIndexOpFoldResult(getAsOpFoldResult(v)) wherever you need it. I wouldn't call getAsOpFoldResult from dispatchIndexOpFoldResult, because it does not fit with the name of the function. This function is just a switch that populates two vectors; it's not meant to analyze any IR.
Why do you need this functionality?
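The layering suggested in this review — fold at the call site, keep the dispatch helper a pure switch — can be sketched with hypothetical stand-ins. None of these names are MLIR API: `foldIfConstant` plays the role of `getAsOpFoldResult`, with a toy map standing in for constant-producing ops in the IR.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <map>
#include <string>
#include <variant>
#include <vector>

using FoldResult = std::variant<int64_t, std::string>;
const int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Stand-in for getAsOpFoldResult: if the SSA value is produced by a constant
// op (looked up in a toy "IR" map), return the folded integer; otherwise
// return the value unchanged.
FoldResult foldIfConstant(const std::string &value,
                          const std::map<std::string, int64_t> &constants) {
  auto it = constants.find(value);
  if (it != constants.end())
    return it->second;
  return value;
}

// The dispatch helper stays a pure switch, as the reviewer requested:
// it never inspects IR itself.
void dispatchFoldResult(const FoldResult &ofr,
                        std::vector<std::string> &dynamicVec,
                        std::vector<int64_t> &staticVec) {
  if (const int64_t *cst = std::get_if<int64_t>(&ofr)) {
    staticVec.push_back(*cst);
    return;
  }
  dynamicVec.push_back(std::get<std::string>(ofr));
  staticVec.push_back(kDynamic);
}
```

A call site would then write `dispatchFoldResult(foldIfConstant(v, constants), dyn, stat)` rather than teaching the helper to walk defining ops.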
%c3 = arith.constant 3 : index
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
I'm not that familiar with this op anymore, but I expected output_shape [%s0, 4, 2, 3].
Actually, I'm surprised that the verifier allows this op. @MaheshRavishankar to clarify.
yeah, probably missing a canonicalizer here.
I expected the output_shape to match the result type. I.e., an output dim must be static iff the respective dim is static in the result type. Is that not the case?
The verifier doesn't enforce it (and it would be wrong to do so; it's not necessarily wrong IR), but we could convert the dynamic values in output_shape to static values.
For consistency with other tensor dialect ops, I would recommend to make the verifier stricter. See discussion here.
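The stricter invariant being proposed — an output_shape entry is a literal iff the corresponding result dim is static — can be modeled as a small check. This is an illustrative sketch under stated assumptions, not the actual tensor dialect verifier; all names here are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Each output_shape entry is either a literal extent (int64_t) or an SSA
// value (string); the result type's dims use kDynamic for '?'.
using FoldResult = std::variant<int64_t, std::string>;
const int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Invariant: an output_shape entry is a literal iff the corresponding result
// dim is static, and static extents must agree.
bool outputShapeMatchesResultType(const std::vector<FoldResult> &outputShape,
                                  const std::vector<int64_t> &resultDims) {
  if (outputShape.size() != resultDims.size())
    return false;
  for (size_t i = 0; i < outputShape.size(); ++i) {
    bool entryStatic = std::holds_alternative<int64_t>(outputShape[i]);
    bool dimStatic = resultDims[i] != kDynamic;
    if (entryStatic != dimStatic)
      return false;
    if (entryStatic && std::get<int64_t>(outputShape[i]) != resultDims[i])
      return false;
  }
  return true;
}
```

Under this rule, `output_shape [%s0, 4, 2, 3]` into `tensor<?x4x2x3xf32>` passes, while `output_shape [%s0, %s1, %c2, %c3]` into the same static result type is rejected.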
@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
    staticVec.push_back(apInt.getSExtValue());
    return;
  }

  OpFoldResult result = getAsOpFoldResult(v);
  if (auto attr = result.dyn_cast<Attribute>()) {
I think you could dyn_cast the OpFoldResult directly to IntegerAttr, which saves the forced cast on the next line. Something like if (auto iattr = dyn_cast<IntegerAttr>(result)).
This is maybe OK... but it seems like it is trying to account for some other issue. Is the issue the expand_shape op? Then you are probably missing a canonicalizer that can fold the constant shapes into the output shape.
Hi @matthias-springer @MaheshRavishankar, given my six months of experience with MLIR, I am uncertain whether these changes should be made in the MLIR repository or within the third-party library. The builder is currently failing, with 16 test cases failing; I have managed to resolve 15 of them by adjusting the static shape and stride information. I believe this change will not only benefit the third-party library but also improve some MLIR tests by enabling vector code generation through static stride information. Additionally, I am aware of issues with the verifier, which is incorrectly legalizing certain operations. I would appreciate feedback on whether it is advisable to proceed with this PR. Your guidance will be invaluable. Thank you.

cc: @javedabsar1
@rutkoor Can you post the C++ code that builds the … As the name suggests, …
This is the code:

It is invoking the decomposeMixedValues function, which tries to extract an integer from the expressions. Instead of invoking decomposeMixedValues, it should invoke dispatchIndexOpFoldResult.
Can you also post the call site of …

What I am wondering: where is this example op coming from? I'd like to understand why the …
Without the changes from this PR, this test case is invalid; it throws the below error:

The …
OK, I understand now what you are trying to do. In my opinion, the input IR is invalid. We should make the verifier stricter to reject such ops.

Invalid:

```mlir
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
```

Valid:

```mlir
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, 4, 2, 3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
```

The valid IR works with … As for the reason why the first one should be invalid and the second one valid, take a look at this discussion: https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct/82612. We were discussing the same issue in the context of …
Thanks a lot @matthias-springer for providing the details and clarification. Closing the PR now.

cc: @javedabsar1
We can't have this be a verifier issue. Let me respond on the discourse thread.
Ok, I misread the IR... I agree with what Matthias says here. The following are valid:

```mlir
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, 4, 2, 3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
```

or

```mlir
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x?x?xf32>
```

but this should be invalid:

```mlir
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
```

That is inconsistent...