Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][shape] Turn ShapeOfOp folding into canonicalization pattern #74438

Merged
merged 1 commit into from
Dec 6, 2023

Conversation

matthias-springer
Copy link
Member

The ShapeOfOp folder used to generate invalid IR.

Input:

%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>

Output:

%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'

This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original shape.shape_of op had a return type of tensor<?xindex>, but the folded attribute (materialized as a shape.const_shape op) must have a type of tensor<0xf32> to be valid.

This commit fixes tests such as mlir/test/Dialect/Shape/canonicalize.mlir when verifying the IR after each pattern application (#74270).

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-mlir-shape

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The ShapeOfOp folder used to generate invalid IR.

Input:

%0 = shape.shape_of %arg1 : tensor&lt;index&gt; -&gt; tensor&lt;?xindex&gt;

Output:

%0 = "shape.const_shape"() &lt;{shape = dense&lt;&gt; : tensor&lt;0xindex&gt;}&gt; : () -&gt; tensor&lt;?xindex&gt;
error: 'shape.const_shape' op inferred type(s) 'tensor&lt;0xindex&gt;' are incompatible with return type(s) of operation 'tensor&lt;?xindex&gt;'

This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original shape.shape_of op had a return type of tensor&lt;?xindex&gt;, but the folded attribute (materialized as a shape.const_shape op) must have a type of tensor&lt;0xf32&gt; to be valid.

This commit fixes tests such as mlir/test/Dialect/Shape/canonicalize.mlir when verifying the IR after each pattern application (#74270).


Full diff: https://github.com/llvm/llvm-project/pull/74438.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td (-1)
  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+25-9)
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
+          loc, op.getResult().getType(), constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
-               ExtractFromShapeOfExtentTensor>(context);
+               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+      context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(

Copy link
Member

@frgossen frgossen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this. Think a test would be good.

The `ShapeOfOp` folder used to generate invalid IR.

Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xf32>` to be valid.

This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (llvm#74270).
@matthias-springer matthias-springer merged commit dbb782d into llvm:main Dec 6, 2023
3 of 4 checks passed
mlevesquedion pushed a commit to mlevesquedion/stablehlo that referenced this pull request Dec 7, 2023
In llvm/llvm-project#74438, the folder for `shape.shape_of` is changed to a canonicalizer. This means that constant shapes no longer get folded automatically (`--canonicalize` must be used). This will cause a test failure when we do the next LLVM integrate, because the `broadcast_select_reify` test expects the `shape.shape_of` operation to be folded into `shape.const_shape`. The test also expects the constant shape value to be pushed to the rightmost arg of the `shape.broadcast` operation, which will not be the case if canonicalization is not applied.

Additional context:
- The old folder for `shape.shape_of` returned its input shape as a
  tensor attribute, so it would [automatically get materialized](https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes) [to a `shape.const_shape` op](https://github.com/llvm/llvm-project/blob/98d8dce6e9e21a995f6a06fa4485fa529931be37/mlir/lib/Dialect/Shape/IR/Shape.cpp#L154-L156).
- The new canonicalizer does the materialization explicitly.
- [BroadcastOp is Commutative](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L57) and [ConstShape is ConstantLike](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L105), so
  if `shape.shape_of` is folded to `shape.const_shape`, the resulting
  value becomes the rightmost argument to `shape.broadcast`. Indeed,
  according to the docs [constant arguments of commutative ops are
  shifted to the
  right](https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules:~:text=Move%20constant%20operands%20to%20commutative%20operators%20to%20the%20right%20side), and this is implemented [here](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/lib/IR/Operation.cpp#L802).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants