-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][shape] Turn ShapeOfOp
folding into canonicalization pattern
#74438
[mlir][shape] Turn ShapeOfOp
folding into canonicalization pattern
#74438
Conversation
@llvm/pr-subscribers-mlir-shape @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe Input:
Output:
This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original This commit fixes tests such as Full diff: https://github.com/llvm/llvm-project/pull/74438.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
let hasCanonicalizer = 1;
- let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
// ShapeOfOp
//===----------------------------------------------------------------------===//
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
- auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
- if (!type || !type.hasStaticShape())
- return nullptr;
- Builder builder(getContext());
- return builder.getIndexTensorAttr(type.getShape());
-}
-
namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+ using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+ PatternRewriter &rewriter) const override {
+ auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+ if (!type || !type.hasStaticShape())
+ return failure();
+ Location loc = op.getLoc();
+ Value constShape =
+ rewriter
+ .create<ConstShapeOp>(loc,
+ rewriter.getIndexTensorAttr(type.getShape()))
+ .getResult();
+ if (constShape.getType() != op.getResult().getType())
+ constShape = rewriter.create<tensor::CastOp>(
+ loc, op.getResult().getType(), constShape);
+ rewriter.replaceOp(op, constShape);
+ return success();
+ }
+};
+
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
- ExtractFromShapeOfExtentTensor>(context);
+ ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+ context);
}
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing this. Think a test would be good.
The `ShapeOfOp` folder used to generate invalid IR. Input: ``` %0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex> ``` Output: ``` %0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex> error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>' ``` This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xf32>` to be valid. This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (llvm#74270).
8be89bf
to
18ab550
Compare
In llvm/llvm-project#74438, the folder for `shape.shape_of` is changed to a canonicalizer. This means that constant shapes no longer get folded automatically (`--canonicalize` must be used). This will cause a test failure when we do the next LLVM integrate, because the `broadcast_select_reify` test expects the `shape.shape_of` operation to be folded into `shape.const_shape`. The test also expects the constant shape value to be pushed to the rightmost arg of the `shape.broadcast` operation, which will not be the case if canonicalization is not applied. Additional context: - The old folder for `shape.shape_of` returned its input shape as a tensor attribute, so it would [automatically get materialized](https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes) [to a `shape.const_shape` op](https://github.com/llvm/llvm-project/blob/98d8dce6e9e21a995f6a06fa4485fa529931be37/mlir/lib/Dialect/Shape/IR/Shape.cpp#L154-L156). - The new canonicalizer does the materialization explicitly. - [BroadcastOp is Commutative](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L57) and [ConstShape is ConstantLike](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L105), so if `shape.shape_of` is folded to `shape.const_shape`, the resulting value becomes the rightmost argument to `shape.broadcast`. Indeed, according to the docs [constant arguments of commutative ops are shifted to the right](https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules:~:text=Move%20constant%20operands%20to%20commutative%20operators%20to%20the%20right%20side), and this is implemented [here](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/lib/IR/Operation.cpp#L802).
The
ShapeOfOp
folder used to generate invalid IR.Input:
Output:
This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original
shape.shape_of
op had a return type oftensor<?xindex>
, but the folded attribute (materialized as ashape.const_shape
op) must have a type oftensor<0xf32>
to be valid.This commit fixes tests such as
mlir/test/Dialect/Shape/canonicalize.mlir
when verifying the IR after each pattern application (#74270).