From a461e0a36b85b7772e61191591ebd31fe0d632e9 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 21 Sep 2023 16:22:32 -0700 Subject: [PATCH] [mlir][TilingInterface] NFC code changes separated out from introduction of `scf::tileUsingSCFForallOp`. This patch contains NFC changes that are a precursor to the introduction of `scf::tileUsingSCFForallOp` method. --- .../SCF/Transforms/TileUsingInterface.h | 4 +- .../TransformOps/LinalgTransformOps.cpp | 14 +-- .../SCF/Transforms/TileUsingInterface.cpp | 112 +++++++++++------- .../tile-and-fuse-using-interface.mlir | 18 +-- .../tile-fuse-and-yield-using-interface.mlir | 2 +- .../tile-pad-using-interface.mlir | 12 +- .../TilingInterface/tile-using-interface.mlir | 14 +-- .../TilingInterface/TestTilingInterface.cpp | 99 ++++++++-------- 8 files changed, 148 insertions(+), 127 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index ca641c596c7b7bb..9f49d97e141e0c8 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -60,7 +60,7 @@ struct SCFTilingResult { /// of the last op. SmallVector tiledOps; /// The `scf.for` operations that iterate over the tiles. - SmallVector loops; + SmallVector loops; /// Values to use as replacements for the untiled op. Is the same size as the /// number of results of the untiled op. SmallVector replacements; @@ -160,7 +160,7 @@ struct SCFTileAndFuseResult { /// generated operation. llvm::SetVector tiledAndFusedOps; /// The `scf.for` operations that iterate over the tiles. - SmallVector loops; + SmallVector loops; /// The replacement values to use for the tiled and fused operations. 
llvm::DenseMap replacements; }; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 1819ca614a060fd..ca3db7401e38caa 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll( SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { - SmallVector replacements; - replacements.reserve(toReplace->getNumResults()); - for (OpResult res : toReplace->getResults()) { - auto it = tiledResults->replacements.find(res); - if (it == tiledResults->replacements.end()) - replacements.push_back(res); - else - replacements.push_back(it->getSecond()); + for (OpResult res : toReplace->getResults()) + if (auto replacement = tiledResults->replacements.lookup(res)) + rewriter.replaceAllUsesWith(res, replacement); + if (toReplace->use_empty()) { + rewriter.eraseOp(toReplace); } - rewriter.replaceOp(toReplace, replacements); } // Report back the relevant handles to the transform op. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 6cfba3fef15ebda..c291f26f4c1b38c 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef interchangeVector, return filledVector; } +/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`. +template +static SmallVector getAsOperations(ArrayRef ops) { + return llvm::to_vector( + llvm::map_range(ops, [](auto op) -> Operation * { return op; })); +} +template +static SmallVector +getAsOperations(const SmallVector &ops) { + return getAsOperations(ArrayRef(ops)); +} + +/// Convert a list of `Operation *` to a list of `DstOpTy`. 
+template +static SmallVector castToTypedOperations(ArrayRef ops) { + return llvm::to_vector( + llvm::map_range(ops, [](Operation *op) { return cast(op); })); +} +template +static SmallVector +castToTypedOperations(const SmallVector &ops) { + return castToTypedOperations(ArrayRef(ops)); +} + //===----------------------------------------------------------------------===// // tileUsingSCFForOp implementation. //===----------------------------------------------------------------------===// @@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) { /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, - Value tileSize) { - std::optional ts = getConstantIntValue(tileSize); - if (ts && ts.value() == 1) - return getAsOpFoldResult(tileSize); + OpFoldResult tileSize) { + if (isConstantIntValue(tileSize, 1)) + return tileSize; if (tileDividesIterationDomain( Range{loopRange.offset, loopRange.size, tileSize})) @@ -295,8 +318,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, tileSizeVector.append(numLoops - tileSizeVector.size(), zero); } - scf::SCFTilingResult tilingResult; SmallVector offsets, sizes; + SmallVector forLoops; { // If there is an interchange specified, permute the iteration domain and // the tile sizes. @@ -319,8 +342,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, // 3. Materialize an empty loop nest that iterates over the tiles. These // loops for now do not return any values even if the original operation has // results. 
- tilingResult.loops = generateTileLoopNest( - rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); + forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain, + tileSizeVector, offsets, sizes); if (!interchangeVector.empty()) { auto inversePermutation = invertPermutationVector(interchangeVector); @@ -330,30 +353,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, } LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { + if (!forLoops.empty()) { llvm::dbgs() << "LoopNest shell :\n"; - tilingResult.loops.front().dump(); + forLoops.front().dump(); llvm::dbgs() << "\n"; } }); // 4. Generate the tiled implementation within the inner most loop. - if (!tilingResult.loops.empty()) - rewriter.setInsertionPoint( - tilingResult.loops.back().getBody()->getTerminator()); + if (!forLoops.empty()) + rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator()); FailureOr tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); - tilingResult.tiledOps.append(tiledImplementation->tiledOps); + if (op->getNumResults() == 0) { - // nothing more to do. - return tilingResult; + return scf::SCFTilingResult{ + tiledImplementation->tiledOps, getAsOperations(forLoops), {}}; } // If loops are empty, the tiled op is used as the replacement for the untiled // op. - if (tilingResult.loops.empty()) { - tilingResult.replacements = tiledImplementation->tiledValues; - return tilingResult; + if (forLoops.empty()) { + return scf::SCFTilingResult{tiledImplementation->tiledOps, + getAsOperations(forLoops), + tiledImplementation->tiledValues}; } // 5. Yield all the results of the tiled operation. 
The surrounding loop @@ -377,18 +400,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, destinationTensors))) return rewriter.notifyMatchFailure(op, "failed to get destinations"); - tilingResult.replacements = yieldTiledValues( + SmallVector replacements = yieldTiledValues( rewriter, destinationTensors, tiledImplementation.value(), - resultOffsetsList, resultSizesList, tilingResult.loops); - + resultOffsetsList, resultSizesList, forLoops); LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { + if (!forLoops.empty()) { llvm::dbgs() << "After tiled implementation :\n"; - tilingResult.loops.front().dump(); + forLoops.front().dump(); llvm::dbgs() << "\n"; } }); - return tilingResult; + return scf::SCFTilingResult{tiledImplementation->tiledOps, + getAsOperations(forLoops), replacements}; } FailureOr @@ -466,6 +489,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, results.mergeOp = mergeOp; return results; } + //===----------------------------------------------------------------------===// // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. //===----------------------------------------------------------------------===// @@ -636,7 +660,9 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( } // 1. First tile the consumer. 
- scf::SCFTileAndFuseResult tileAndFuseResult; + SmallVector forLoops; + SetVector fusedProducers, tiledAndFusedOps; + DenseMap replacements; llvm::SmallDenseMap yieldedValueToResultNumber; { FailureOr tilingResult = @@ -644,20 +670,21 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( if (failed(tilingResult)) return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); for (auto *tiledOp : tilingResult->tiledOps) - tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); - tileAndFuseResult.loops = std::move(tilingResult->loops); - for (const auto &result : llvm::enumerate( - llvm::zip(consumer->getResults(), tilingResult->replacements))) { - tileAndFuseResult.replacements[std::get<0>(result.value())] = - std::get<1>(result.value()); + tiledAndFusedOps.insert(tiledOp); + forLoops = castToTypedOperations(tilingResult->loops); + for (auto [index, origValue, replacement] : + llvm::enumerate(consumer->getResults(), tilingResult->replacements)) { + replacements[origValue] = replacement; yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( - result.index())] = result.index(); + index)] = index; } } // If there are no loops generated, fusion is immaterial. - if (tileAndFuseResult.loops.empty()) - return tileAndFuseResult; + if (forLoops.empty()) { + return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, + getAsOperations(forLoops), replacements}; + } // 2. Typically, the operands of the tiled operation are slices of the // operands of the untiled operation. These are expressed in IR using @@ -674,7 +701,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( }; std::deque candidates; - addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); + addCandidateSlices(tiledAndFusedOps.back(), candidates); OpBuilder::InsertionGuard g(rewriter); while (!candidates.empty()) { // Traverse the slices in BFS fashion. 
@@ -684,19 +711,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. - std::optional fusedProducer = - tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, - tileAndFuseResult.loops); - if (!fusedProducer) + std::optional fusedResult = + tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); + if (!fusedResult) continue; if (Operation *tiledAndFusedOp = - fusedProducer->tiledAndFusedProducer.getDefiningOp()) { - tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); + fusedResult->tiledAndFusedProducer.getDefiningOp()) { + fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); + tiledAndFusedOps.insert(tiledAndFusedOp); addCandidateSlices(tiledAndFusedOp, candidates); } } - return tileAndFuseResult; + return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, + getAsOperations(forLoops), replacements}; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 4f5900fda3e76bd..cf5a1b828f95b75 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> %d1 = tensor.dim %arg1, %c1 : tensor %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"} + %gemm = linalg.matmul {__internal_transform__ = "fusion"} ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor return %gemm : tensor @@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : 
tensor, ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor %generic = linalg.generic { - __internal_linalg_transform__ = "fusion", + __internal_transform__ = "fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%gemm, %arg2 : tensor, tensor) outs(%init : tensor) { @@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r %d2 = tensor.dim %rhs1, %c1 : tensor %init1 = tensor.empty(%d0, %d2) : tensor %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor - %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} + %gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"} ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor return %gemm1 : tensor } @@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor %init1 = tensor.empty(%d1, %d0) : tensor %transpose = linalg.generic { - __internal_linalg_transform__ = "fusion", + __internal_transform__ = "fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%gemm : tensor) outs(%init1 : tensor) { @@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor, tensor) outs(%1 : tensor) -> tensor %3 = linalg.generic { - __internal_linalg_transform__ = "gemm_interchange_fusion", + __internal_transform__ = "gemm_interchange_fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor) outs(%0 : tensor) { @@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "gemm_plus_gemm_fusion"} + __internal_transform__ = 
"gemm_plus_gemm_fusion"} ins(%2, %2 : tensor, tensor) outs(%5 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : @@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor, %arg1: tensor (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "gemm_plus_gemm_fusion"} + __internal_transform__ = "gemm_plus_gemm_fusion"} ins(%2, %2 : tensor, tensor) outs(%5 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : @@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor, %arg1: tensor %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] %2 = linalg.matmul - {__internal_linalg_transform__ = "gemm_sequence_fusion"} + {__internal_transform__ = "gemm_sequence_fusion"} ins(%1, %arg5 : tensor, tensor) outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] return %2 : tensor @@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> { linalg.yield %10, %9 : f32, f32 } -> (tensor<30xf32>, tensor<30x3xf32>) %6 = linalg.generic { - __internal_linalg_transform__ = "reduction_sequence_fusion", + __internal_transform__ = "reduction_sequence_fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir index f47850a5cb6d229..f725d19e14a0c5b 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir @@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor %d2 = tensor.dim %rhs1, %c1 : tensor %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : 
tensor) -> tensor - %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"} + %gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"} ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor return %gemm0, %gemm1 : tensor, tensor } diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir index 2d6069973c8bf78..cbc5d6c186d6d34 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir @@ -6,7 +6,7 @@ func.func @dynamic_2d_pad_tensor(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_2dtiling"}: tensor to tensor + } {__internal_transform__ = "pad_2dtiling"}: tensor to tensor return %0 : tensor } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)> @@ -38,7 +38,7 @@ func.func @dynamic_2d_pad_tensor_inner_tiling(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_inner_tiling"}: tensor to tensor + } {__internal_transform__ = "pad_inner_tiling"}: tensor to tensor return %0 : tensor } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)> @@ -68,7 +68,7 @@ func.func @static_pad_tensor(%input_tensor: tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } {__internal_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } // CHECK-LABEL: func @static_pad_tensor( @@ -95,7 +95,7 @@ func.func @static_pad_tensor_inner_tiling(%input_tensor: 
tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } // CHECK-LABEL: func @static_pad_tensor_inner_tiling( @@ -122,7 +122,7 @@ func.func @dynamic_2d_pad_tensor_outer_tiling(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_outer_tiling"}: tensor to tensor + } {__internal_transform__ = "pad_outer_tiling"}: tensor to tensor return %0 : tensor } // CHECK-LABEL: func @dynamic_2d_pad_tensor_outer_tiling @@ -134,7 +134,7 @@ func.func @static_pad_tensor_outer_tiling(%input_tensor: tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } // CHECK-LABEL: func @static_pad_tensor_outer_tiling diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir index cacef3c47b5e1cd..2153eb6f237fcfd 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -2,7 +2,7 @@ func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"} + %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor @@ -45,7 +45,7 @@ func.func 
@simple_matmul(%arg0 : tensor, %arg1 : tensor, func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"} + linalg.matmul {__internal_transform__ = "simple_gemm_memref"} ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) return @@ -92,7 +92,7 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x %0:2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} - {__internal_linalg_transform__ = "parallel_generic_transpose"} + {__internal_transform__ = "parallel_generic_transpose"} ins(%arg0 : tensor<128x200x300xf32>) outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): @@ -139,7 +139,7 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, %0 = linalg.conv_2d_nhwc_hwcf { strides = dense<[2, 3]> : tensor<2xi64>, dilation = dense<[4, 5]> : tensor<2xi64>, - __internal_linalg_transform__ = "simple_conv"} + __internal_transform__ = "simple_conv"} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor @@ -205,7 +205,7 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} - {__internal_linalg_transform__ = "indexed_semantics"} + {__internal_transform__ = "indexed_semantics"} ins(%arg0: tensor) outs(%arg1: tensor) { ^bb0(%arg2: f32, %arg3: f32): @@ -229,7 +229,7 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> func.func @interchange_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.matmul {__internal_linalg_transform__ = "gemm_interchange"} + %0 = linalg.matmul {__internal_transform__ = "gemm_interchange"} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor @@ -283,7 +283,7 @@ func.func 
@interchange_matmul(%arg0 : tensor, %arg1 : tensor, // CHECK: memref.subview // CHECK: linalg.copy func.func @linalg_copy_matmul(%a: memref, %b: memref) { - linalg.copy {__internal_linalg_transform__ = "simple_copy_memref"} + linalg.copy {__internal_transform__ = "simple_copy_memref"} ins(%a : memref) outs(%b : memref) return } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 2fcc7bcadb60450..2573e11979dbc47 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -37,51 +37,51 @@ using namespace mlir; namespace { /// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__"; +const StringLiteral kTransformMarker = "__internal_transform__"; /// Helper class to control application of linalg transformation patterns. /// Control comes in 2 forms: /// 1. attribute matching and setting behavior using the attribute named -/// `kLinalgTransformMarker`. This can be used to build a state machine +/// `kTransformMarker`. This can be used to build a state machine /// using attributes and incrementally applying patterns to advance states. /// 2. filter function, which is a simple lambda on the Operation* that /// returns a LogicalResult. 
-struct LinalgTransformationFilter { +struct TransformationFilter { using FilterFunction = std::function; - explicit LinalgTransformationFilter( + explicit TransformationFilter( ArrayRef matchDisjunction = {}, std::optional replacement = std::nullopt); - explicit LinalgTransformationFilter( + explicit TransformationFilter( const FilterFunction &f, ArrayRef matchDisjunction = {}, std::optional replacement = std::nullopt); - LinalgTransformationFilter(LinalgTransformationFilter &&) = default; - LinalgTransformationFilter(const LinalgTransformationFilter &) = default; + TransformationFilter(TransformationFilter &&) = default; + TransformationFilter(const TransformationFilter &) = default; LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; - void replaceLinalgTransformationFilter(PatternRewriter &rewriter, - Operation *op) const; + void replaceTransformationFilter(PatternRewriter &rewriter, + Operation *op) const; - LinalgTransformationFilter &addFilter(const FilterFunction &f) { + TransformationFilter &addFilter(const FilterFunction &f) { if (f) filters.push_back(f); return *this; } template - LinalgTransformationFilter &addOpFilter() { + TransformationFilter &addOpFilter() { return addFilter( [](Operation *op) { return success(isa(op)); }); } - LinalgTransformationFilter &addOpNameFilter(StringRef opName) { + TransformationFilter &addOpNameFilter(StringRef opName) { return addFilter([opName](Operation *op) { return success(op->getName().getStringRef() == opName); }); } - LinalgTransformationFilter &setMatchByDefault() { + TransformationFilter &setMatchByDefault() { matchByDefault = true; return *this; } @@ -95,20 +95,19 @@ struct LinalgTransformationFilter { bool matchByDefault; }; -LinalgTransformationFilter::LinalgTransformationFilter( +TransformationFilter::TransformationFilter( ArrayRef matchDisjunction, std::optional replacement) : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), replacement(replacement), 
matchByDefault(false) {} -LogicalResult -LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter, - Operation *op) const { +LogicalResult TransformationFilter::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { if (llvm::any_of(filters, [&](const FilterFunction &f) { return failed(f(op)); })) return failure(); - auto attr = op->template getAttrOfType(kLinalgTransformMarker); + auto attr = op->template getAttrOfType(kTransformMarker); if (!attr) { // 1. Has no filter case and matchDisjunction is empty. @@ -134,12 +133,12 @@ LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter, }); } -void LinalgTransformationFilter::replaceLinalgTransformationFilter( +void TransformationFilter::replaceTransformationFilter( PatternRewriter &rewriter, Operation *op) const { if (replacement.has_value()) - op->setAttr(kLinalgTransformMarker, *replacement); + op->setAttr(kTransformMarker, *replacement); else - op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker)); + op->removeAttr(rewriter.getStringAttr(kTransformMarker)); } /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using @@ -147,18 +146,17 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter( /// using a `filter` to avoid recursive application. struct TestTileUsingSCFForOp : public OpInterfaceRewritePattern { - TestTileUsingSCFForOp( - MLIRContext *context, scf::SCFTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) + TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`. 
- TestTileUsingSCFForOp( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) + TestTileUsingSCFForOp(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -179,13 +177,13 @@ struct TestTileUsingSCFForOp } for (auto *tiledOp : tilingResult->tiledOps) - filter.replaceLinalgTransformationFilter(rewriter, tiledOp); + filter.replaceTransformationFilter(rewriter, tiledOp); return success(); } private: scf::SCFTilingOptions options; - LinalgTransformationFilter filter; + TransformationFilter filter; }; /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern @@ -196,7 +194,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp : public OpInterfaceRewritePattern { TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( MLIRContext *context, scf::SCFTileAndFuseOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + TransformationFilter filter = TransformationFilter(), PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -205,7 +203,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( StringRef opName, MLIRContext *context, scf::SCFTileAndFuseOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + TransformationFilter filter = TransformationFilter(), PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -229,14 +227,14 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp } rewriter.replaceOp(op, replacements); - 
filter.replaceLinalgTransformationFilter( + filter.replaceTransformationFilter( rewriter, tileAndFuseResult->tiledAndFusedOps.front()); return success(); } private: scf::SCFTileAndFuseOptions options; - LinalgTransformationFilter filter; + TransformationFilter filter; }; /// Pattern to tile a consumer and fuse producer with it @@ -254,7 +252,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp TestTileConsumerFuseAndYieldProducerUsingSCFForOp( MLIRContext *context, scf::SCFTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + TransformationFilter filter = TransformationFilter(), PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -302,6 +300,8 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp std::deque candidates; addCandidateSlices(tilingResult->tiledOps.back(), candidates); OpBuilder::InsertionGuard g(rewriter); + auto forLoops = llvm::to_vector(llvm::map_range( + tilingResult->loops, [](auto op) { return cast(op); })); while (!candidates.empty()) { // Traverse the slices in BFS fashion. tensor::ExtractSliceOp candidateSliceOp = candidates.front(); @@ -309,8 +309,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp // Materialize the slice of the producer in place. std::optional fusedProducer = - tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, - tilingResult->loops); + tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); if (!fusedProducer) continue; @@ -318,11 +317,10 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp // to be yielded from within the tiled loop. 
OpResult untiledProducer = fusedProducer->origProducer; if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) { - return !isIgnoredUser(user, tilingResult->loops.front()); + return !isIgnoredUser(user, forLoops.front()); })) { yieldReplacementForFusedProducer(rewriter, candidateSliceOp, - fusedProducer.value(), - tilingResult->loops); + fusedProducer.value(), forLoops); yieldedValuesToOrigValues.push_back(untiledProducer); } @@ -332,7 +330,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp addCandidateSlices(fusedProducerOp, candidates); } - scf::ForOp outermostLoop = tilingResult->loops.front(); + scf::ForOp outermostLoop = forLoops.front(); for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) { Value replacement = outermostLoop.getResult(index); rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { @@ -340,8 +338,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp }); } rewriter.eraseOp(rootOp); - filter.replaceLinalgTransformationFilter(rewriter, - tilingResult->tiledOps.back()); + filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back()); return success(); } @@ -370,7 +367,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp } scf::SCFTilingOptions options; - LinalgTransformationFilter filter; + TransformationFilter filter; }; /// Pattern to lower operations that implement the `TilingInterface` to @@ -453,8 +450,8 @@ static void addPatternForTiling(MLIRContext *context, SmallVector tileSizesOfr = getAsIndexOpFoldResult(context, tileSizes); tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); - LinalgTransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); patterns.add(context, tilingOptions, filter); } @@ -467,8 +464,8 @@ static void addPatternForTileFuseAndYield(MLIRContext *context, SmallVector 
tileSizesOfr = getAsIndexOpFoldResult(context, tileSizes); tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); - LinalgTransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); patterns.add( context, tilingOptions, filter); } @@ -483,8 +480,8 @@ static void addPatternForTileAndFuse(MLIRContext *context, getAsIndexOpFoldResult(context, tileSizes); tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr) .setInterchange(interchange); - LinalgTransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); patterns.add( context, tileAndFuseOptions, filter); }