Skip to content

Commit

Permalink
[mlir][shape] Turn ShapeOfOp folding into canonicalization pattern
Browse files Browse the repository at this point in the history
The `ShapeOfOp` folder used to generate invalid IR.

Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xf32>` to be valid.

This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (#74270).
  • Loading branch information
matthias-springer committed Dec 6, 2023
1 parent e68c265 commit 18ab550
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

Expand Down
34 changes: 25 additions & 9 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
// ShapeOfOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
if (!type || !type.hasStaticShape())
return nullptr;
Builder builder(getContext());
return builder.getIndexTensorAttr(type.getShape());
}

namespace {
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
if (!type || !type.hasStaticShape())
return failure();
Location loc = op.getLoc();
Value constShape =
rewriter
.create<ConstShapeOp>(loc,
rewriter.getIndexTensorAttr(type.getShape()))
.getResult();
if (constShape.getType() != op.getResult().getType())
constShape = rewriter.create<tensor::CastOp>(
loc, op.getResult().getType(), constShape);
rewriter.replaceOp(op, constShape);
return success();
}
};

struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

Expand Down Expand Up @@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
ExtractFromShapeOfExtentTensor>(context);
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
%result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
return %result : !shape.size
}

// -----

// CHECK-LABEL: func @shape_of_0d(
// CHECK-SAME: %[[arg0:.*]]: tensor<f32>
// CHECK: %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
// CHECK: %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[cast]]
func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
%0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
return %0 : tensor<?xindex>
}

0 comments on commit 18ab550

Please sign in to comment.