-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Add support for 0-d shapes in extract-shape_cast folder #116650
Conversation
@llvm/pr-subscribers-mlir-vector Author: Kunwar Grover (Groverkss) ChangesThe extract <-> shape cast folder was conservatively asserting on 0-d vectors. This pr fixes this. This pr also adds more tests for 0d cases and updates related tests to better reflect what they test. Full diff: https://github.com/llvm/llvm-project/pull/116650.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..d7e53e2e14dcfc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1756,11 +1756,6 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
if (!shapeCastOp)
return Value();
- // 0-D vectors not supported.
- assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
- if (hasZeroDimVectors(shapeCastOp))
- return Value();
-
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getShape().take_back(n + 1).front();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..04518a56c3dd20 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -782,23 +782,23 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
// -----
-// CHECK-LABEL: fold_extract_shapecast_negative
-// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
-// CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
- %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
- %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
- return %r : vector<4x2xf32>
+// CHECK-LABEL: fold_extract_shapecast_0d_result
+// CHECK-SAME: %[[IN:.*]]: vector<1x1x1xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[IN]][0, 0, 0] : f32 from vector<1x1x1xf32>
+// CHECK: return %[[R]] : f32
+func.func @fold_extract_shapecast_0d_result(%arg0 : vector<1x1x1xf32>) -> f32 {
+ %0 = vector.shape_cast %arg0 : vector<1x1x1xf32> to vector<f32>
+ %r = vector.extract %0[] : f32 from vector<f32>
+ return %r : f32
}
// -----
-// CHECK-LABEL: dont_fold_0d_extract_shapecast
-// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK-LABEL: fold_extract_shapecast_0d_source
+// CHECK-SAME: %[[IN:.*]]: vector<f32>
+// CHECK: %[[R:.*]] = vector.extract %[[IN]][] : f32 from vector<f32>
// CHECK: return %[[R]] : f32
-func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
+func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
%0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
%r = vector.extract %0[0] : f32 from vector<1xf32>
return %r : f32
@@ -806,11 +806,23 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
// -----
-// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
+// CHECK: return %[[R]] : vector<4x2xf32>
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
+ %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
+ return %r : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: canonicalize_extract_shapecast_to_shapecast
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
// CHECK: return %[[R]]
-func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+func.func @canonicalize_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
%0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
%r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
return %r : vector<12xf32>
|
@llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesThe extract <-> shape cast folder was conservatively asserting on 0-d vectors. This pr fixes this. This pr also adds more tests for 0d cases and updates related tests to better reflect what they test. Full diff: https://github.com/llvm/llvm-project/pull/116650.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..d7e53e2e14dcfc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1756,11 +1756,6 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
if (!shapeCastOp)
return Value();
- // 0-D vectors not supported.
- assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
- if (hasZeroDimVectors(shapeCastOp))
- return Value();
-
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getShape().take_back(n + 1).front();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..04518a56c3dd20 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -782,23 +782,23 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
// -----
-// CHECK-LABEL: fold_extract_shapecast_negative
-// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
-// CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
- %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
- %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
- return %r : vector<4x2xf32>
+// CHECK-LABEL: fold_extract_shapecast_0d_result
+// CHECK-SAME: %[[IN:.*]]: vector<1x1x1xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[IN]][0, 0, 0] : f32 from vector<1x1x1xf32>
+// CHECK: return %[[R]] : f32
+func.func @fold_extract_shapecast_0d_result(%arg0 : vector<1x1x1xf32>) -> f32 {
+ %0 = vector.shape_cast %arg0 : vector<1x1x1xf32> to vector<f32>
+ %r = vector.extract %0[] : f32 from vector<f32>
+ return %r : f32
}
// -----
-// CHECK-LABEL: dont_fold_0d_extract_shapecast
-// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK-LABEL: fold_extract_shapecast_0d_source
+// CHECK-SAME: %[[IN:.*]]: vector<f32>
+// CHECK: %[[R:.*]] = vector.extract %[[IN]][] : f32 from vector<f32>
// CHECK: return %[[R]] : f32
-func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
+func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
%0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
%r = vector.extract %0[0] : f32 from vector<1xf32>
return %r : f32
@@ -806,11 +806,23 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
// -----
-// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
+// CHECK: return %[[R]] : vector<4x2xf32>
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
+ %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
+ return %r : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: canonicalize_extract_shapecast_to_shapecast
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
// CHECK: return %[[R]]
-func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+func.func @canonicalize_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
%0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
%r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
return %r : vector<12xf32>
|
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d Dropped the reverts on 3ad0148020ca91cc288bffd8ad36e25f7555a3bb and c02b8a01b7caf2e4ffe17a123f1bcf59192e4b39 after fixes upstream. Also carries a cherry pick for llvm/llvm-project#116650
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Just for viz., I was planning to send an RFC (EOW or ENW) to remove support for 0-D vectors from the vector type. The number of bugs we are dealing with and special-casing we are introducing to support 0-D vectors is not worth the value they bring (which it's still unclear to me). So, if you have any specific concerns about that idea, it would be great to know sooner than later (Discord?) :)
We can discuss on the RFC, but I think we need to have support for 0D vector. Without that the system doesnt seem to be closed. For example, tensor and memrefs allow 0D vectors, but having vectors being an outlier and not supporting 0D vector seems like a gap. We are having 0D vector issues because of it being supported patchily. Having fully rounded support and effective handling of 0D vector on lowering to LLVM is a better end state. |
Thank you for pointing this out! I think it would be helpful to provide concrete examples where avoiding That said, I wonder if strict consistency across these types is always desirable. After all, the reason for having different types (e.g., tensor, memref, vector) is to capture distinct use cases and abstractions, right? Personally, I see a different kind of inconsistency here. The ability to represent a scalar element in multiple ways ( I look forward to chatting more once the RFC is ready :) |
Agreed. We should collect examples where havig 0D vectors is useful, but a better place might be in response to the RFC.
Quick though here is that f32 and vector shouldnt matter in the end, but in the vector dialect itself if we dont have vector and instead rely on f32, it will cause a lot of bloat to keep the type consistent. For example some operations might be defined as operating on vector operands, but if a producer generates a f32 value there will be IR that is generated just to put the f32 into a vector<1xf32> , which is unnecessary. It just introduces artifacts into the code that makes analysis and transformations harder. Operations that "reduce" vectors could now generate a vector value or a scalar value depending on how the reduction is done (or IMO the worse solution of vector<1x1x...f32>). If the vector dialect is consistent in always generating vector types, and the lowering from vectors handles vector as effectively an f32, that seems to be the most consistent handling IMO. Sorry, I am prefetching some discussion that will happen again on the RFC, but food for thought :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG % one question
Still carrying revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650 --------- Signed-off-by: Simon Camphausen <[email protected]> Co-authored-by: Simon Camphausen <[email protected]>
That's a pretty large design change, no? Not that such things can't be justified, but remember: the vector dialect does not "own" the built-in vector type (and many would argue, is not even the primary user of the type). I think it is fine to propose something new, but at what point does this become "vector dialect2 -- now with its own type and specific opinions on dimensionality"? Just warning that I think such an RFC will be a quite difficult discussion and might not be profitable. |
From my perspective, having reviewed and refactored a fair amount of Vector patterns, I don’t think it represents a major design shift. That said, I could be overlooking something important, and this is precisely why I believe an RFC would be valuable - it would help clarify any potential concerns or implications we might not have considered.
In my view, the existing handling of 0-D vectors introduces a fair amount of confusion and misunderstanding. An RFC would create an opportunity to bring these issues to light, better define our needs as a community, and address any misalignments. Even if we ultimately decide to keep 0-D vectors, we could use the discussion to clarify ambiguities (of which there are many!) and improve consistency. Surely, that would benefit the community and strengthen the foundation we’re building on. |
My goal with the RFC was to gather feedback and provide some recommendations after several years dealing with this situation. It has been discussed multiple times, including with Nicolas, that 0-D vectors were introduced by inertia, without a full understanding of the implications. Over time, it has become evident that the level of complexity and trouble they introduced were definitely not expected. Something that it’s clear to me is that we can’t continue with the current instability and bug rate. The Vector dialect IR is ambiguous, and maintaining this ambiguity has proven over the years to making the situation even worse. If removing 0-D support from the built-in vector type is too controversial, excluding 0-D vectors from the Vector dialect could also be considered. Honestly, I’m trying to help here but not willing to bring more drama to Discourse. I think we have enough for now. I may let things settle for a while. |
I didn't mean to discourage the discussion. Was just trying to advise on not falling into the same discussion trap that has happened on the built-in vector type in the past. Scope the discussion to what is right for the vector dialect first, and then maybe see how that extends to built-in types. Do it the other way, and it will rathole on a consensus group that is doing a lot of different stuff. It's a relatively small group of active contributors to this area. Agreed on the project wide drama and bystander issues being a net negative. Perhaps start following the governance principles and be explicit about who the primary participants are for the limited scope being discussed. |
(FTR - I am a bystander on vector-0D and defer to the active contributors. I was primarily concerned with going directly at a redefinition of the built-in types. That is a very big discussion to have and historically starting there has caused the actual issue to not get discussed. My only anecdotal opinion on 0D is that it always seems tantalizing to either remove or embrace it, but either way, there tends to be a lot of degenerate cases and there is no way through but to get them defined and handled consistently) |
I see your point. To me, this is more about making Vector (dialect) and Vector Ops more opinionated in terms of
Agreed - consistency is crucial. Given the amount of special-casing required for 0-D vectors, I lean toward the idea of making Vector Ops "prefer" 1-D vectors over 0-D vectors. This could simplify handling and reduce ambiguities while maintaining flexibility where necessary. |
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…19287) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650.
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…19304) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650.
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…a0 (#19321) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. This time, we have some changes related to tablegen renaming in the vector dialect and op syntax changes in the bufferization dialect.
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…19334) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Removed `FieldParser`s for optional enums that get autogenerated as of llvm/llvm-project#117719.
Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…19338) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]>
…19184) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d Dropped the reverts on 3ad0148020ca91cc288bffd8ad36e25f7555a3bb and c02b8a01b7caf2e4ffe17a123f1bcf59192e4b39 after fixes upstream. Also carries a cherry pick for llvm/llvm-project#116650
…19245) Still carrying revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650 --------- Signed-off-by: Simon Camphausen <[email protected]> Co-authored-by: Simon Camphausen <[email protected]>
…ree-org#19287) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650.
…ree-org#19304) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650.
…a0 (iree-org#19321) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. This time, we have some changes related to tablegen renaming in the vector dialect and op syntax changes in the bufferization dialect.
3e9c34d
to
e3419a9
Compare
…19184) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d Dropped the reverts on 3ad0148020ca91cc288bffd8ad36e25f7555a3bb and c02b8a01b7caf2e4ffe17a123f1bcf59192e4b39 after fixes upstream. Also carries a cherry pick for llvm/llvm-project#116650 Signed-off-by: Giacomo Serafini <[email protected]>
…19245) Still carrying revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650 --------- Signed-off-by: Simon Camphausen <[email protected]> Co-authored-by: Simon Camphausen <[email protected]> Signed-off-by: Giacomo Serafini <[email protected]>
…ree-org#19287) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Giacomo Serafini <[email protected]>
…ree-org#19304) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Giacomo Serafini <[email protected]>
…a0 (iree-org#19321) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. This time, we have some changes related to tablegen renaming in the vector dialect and op syntax changes in the bufferization dialect. Signed-off-by: Giacomo Serafini <[email protected]>
…ree-org#19334) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Removed `FieldParser`s for optional enums that get autogenerated as of llvm/llvm-project#117719. Signed-off-by: Giacomo Serafini <[email protected]>
…ree-org#19338) Still carrying a revert for 1004865f1ca41a9581da8747f34b29862d3ebc3d and a cherry pick for llvm/llvm-project#116650. Signed-off-by: Jakub Kuderski <[email protected]> Signed-off-by: Giacomo Serafini <[email protected]>
llvm#116650) The extract <-> shape cast folder was conservatively asserting and failing on 0-d vectors. This pr fixes this. This pr also adds more tests for 0d cases and updates related tests to better reflect what they test.
The extract <-> shape cast folder was conservatively asserting and failing on 0-d vectors. This pr fixes this.
This pr also adds more tests for 0d cases and updates related tests to better reflect what they test.