
[mlir] Handle arith.const expr in dispatchIndexOpFoldResult func #122432

Closed · wants to merge 2 commits
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}

OpFoldResult result = getAsOpFoldResult(v);
if (auto attr = result.dyn_cast<Attribute>()) {
Contributor:

I think you could dyn_cast the OpFoldResult directly as IntegerAttr; that saves the forced cast on the next line. Something like if (auto iattr = dyn_cast<IntegerAttr>(result)).

APInt apInt = cast<IntegerAttr>(attr).getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}

dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:

// -----

func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
Member:

I'm not that familiar with this op anymore, but I expected output_shape [%s0, 4, 2, 3].

Member (@matthias-springer, Jan 10, 2025):

Actually, I'm surprised that the verifier allows this op. @MaheshRavishankar to clarify.

Contributor:

Yeah, probably missing a canonicalizer here.

Member:

I expected the output_shape to match the result type, i.e., an output dim must be static iff the respective dim is static in the result type. Is that not the case?

Contributor:

The verifier doesn't enforce it (and it would be wrong to do so; it's not necessarily invalid IR), but we could convert the dynamic values in output_shape to static values.

Member:

For consistency with other tensor dialect ops, I would recommend making the verifier stricter. See discussion here.

return %expand : tensor<?x4x2x3xf32>
}
// CHECK: func @bubble_parallel_reshapes2
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x2x2x6xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]