Skip to content

Commit

Permalink
[mlir][linalg] Add a folder for transpose(fill) -> fill (llvm#83623)
Browse files Browse the repository at this point in the history
This is similar to the existing folder for a linalg.copy. Transposing a
filled tensor is the same as filling the destination of the transpose.
  • Loading branch information
qedawkins authored Mar 2, 2024
1 parent f505a92 commit 205dce6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
18 changes: 17 additions & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
}
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
rewriter.replaceOpWithNewOp<FillOp>(
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
transposeOp.getDpsInitOperand(0)->get());
return success();
}
return failure();
}
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
Expand All @@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill>(context);
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso

// -----

// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0.0 : f32
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
return %transpose : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @broadcast_same_shape(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
Expand Down

0 comments on commit 205dce6

Please sign in to comment.