Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][TilingInterface] NFC code changes separated out from introduction of scf::tileUsingSCFForallop. #67081

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct SCFTilingResult {
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
SmallVector<Operation *> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
Expand Down Expand Up @@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
/// generated operation.
llvm::SetVector<Operation *> tiledAndFusedOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
SmallVector<Operation *> loops;
/// The replacement values to use for the tiled and fused operations.
llvm::DenseMap<Value, Value> replacements;
};
Expand Down
14 changes: 5 additions & 9 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
SmallVector<Operation *> opsToReplace{target};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
SmallVector<Value> replacements;
replacements.reserve(toReplace->getNumResults());
for (OpResult res : toReplace->getResults()) {
auto it = tiledResults->replacements.find(res);
if (it == tiledResults->replacements.end())
replacements.push_back(res);
else
replacements.push_back(it->getSecond());
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res))
rewriter.replaceAllUsesWith(res, replacement);
if (toReplace->use_empty()) {
rewriter.eraseOp(toReplace);
}
rewriter.replaceOp(toReplace, replacements);
}

// Report back the relevant handles to the transform op.
Expand Down
112 changes: 70 additions & 42 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
return filledVector;
}

/// Convert a list of ops of type `SrcOpTy` to a list of `Operation *`.
template <typename SrcOpTy>
static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
  SmallVector<Operation *> genericOps;
  genericOps.reserve(ops.size());
  // Each typed op handle implicitly converts to its underlying `Operation *`.
  for (SrcOpTy op : ops)
    genericOps.push_back(op);
  return genericOps;
}
/// Overload of the above for an owning `SmallVector`, forwarding through a
/// non-owning `ArrayRef` view.
template <typename SrcOpTy>
static SmallVector<Operation *>
getAsOperations(const SmallVector<SrcOpTy> &ops) {
  ArrayRef<SrcOpTy> opsRef(ops);
  return getAsOperations(opsRef);
}

/// Convert a list of `Operation *` to a list of `DstOpTy`. Asserts (via
/// `cast`) if any operation is not an instance of `DstOpTy`.
template <typename DstOpTy>
static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
  SmallVector<DstOpTy> typedOps;
  typedOps.reserve(ops.size());
  for (Operation *op : ops)
    typedOps.push_back(cast<DstOpTy>(op));
  return typedOps;
}
/// Overload of the above for an owning `SmallVector`, forwarding through a
/// non-owning `ArrayRef` view.
template <typename DstOpTy>
static SmallVector<DstOpTy>
castToTypedOperations(const SmallVector<Operation *> &ops) {
  ArrayRef<Operation *> opsRef(ops);
  return castToTypedOperations<DstOpTy>(opsRef);
}

//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand All @@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
Value tileSize) {
std::optional<int64_t> ts = getConstantIntValue(tileSize);
if (ts && ts.value() == 1)
return getAsOpFoldResult(tileSize);
OpFoldResult tileSize) {
if (isConstantIntValue(tileSize, 1))
return tileSize;

if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
Expand Down Expand Up @@ -295,8 +318,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}

scf::SCFTilingResult tilingResult;
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> forLoops;
{
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
Expand All @@ -319,8 +342,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
tilingResult.loops = generateTileLoopNest(
rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
tileSizeVector, offsets, sizes);

if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
Expand All @@ -330,30 +353,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}

LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
if (!forLoops.empty()) {
llvm::dbgs() << "LoopNest shell :\n";
tilingResult.loops.front().dump();
forLoops.front().dump();
llvm::dbgs() << "\n";
}
});

// 4. Generate the tiled implementation within the inner most loop.
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
if (!forLoops.empty())
rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
tilingResult.tiledOps.append(tiledImplementation->tiledOps);

if (op->getNumResults() == 0) {
// nothing more to do.
return tilingResult;
return scf::SCFTilingResult{
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
}

// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
if (tilingResult.loops.empty()) {
tilingResult.replacements = tiledImplementation->tiledValues;
return tilingResult;
if (forLoops.empty()) {
return scf::SCFTilingResult{tiledImplementation->tiledOps,
getAsOperations(forLoops),
tiledImplementation->tiledValues};
}

// 5. Yield all the results of the tiled operation. The surrounding loop
Expand All @@ -377,18 +400,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");

tilingResult.replacements = yieldTiledValues(
SmallVector<Value> replacements = yieldTiledValues(
rewriter, destinationTensors, tiledImplementation.value(),
resultOffsetsList, resultSizesList, tilingResult.loops);

resultOffsetsList, resultSizesList, forLoops);
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
tilingResult.loops.front().dump();
forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
return tilingResult;
return scf::SCFTilingResult{tiledImplementation->tiledOps,
getAsOperations(forLoops), replacements};
}

FailureOr<scf::SCFReductionTilingResult>
Expand Down Expand Up @@ -466,6 +489,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
results.mergeOp = mergeOp;
return results;
}

//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -636,28 +660,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}

// 1. First tile the consumer.
scf::SCFTileAndFuseResult tileAndFuseResult;
SmallVector<scf::ForOp> forLoops;
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
DenseMap<Value, Value> replacements;
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
for (const auto &result : llvm::enumerate(
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
tileAndFuseResult.replacements[std::get<0>(result.value())] =
std::get<1>(result.value());
tiledAndFusedOps.insert(tiledOp);
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
for (auto [index, origValue, replacement] :
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
replacements[origValue] = replacement;
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
result.index())] = result.index();
index)] = index;
}
}

// If there are no loops generated, fusion is immaterial.
if (tileAndFuseResult.loops.empty())
return tileAndFuseResult;
if (forLoops.empty()) {
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}

// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
Expand All @@ -674,7 +701,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
addCandidateSlices(tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
Expand All @@ -684,19 +711,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
tileAndFuseResult.loops);
if (!fusedProducer)
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedResult)
continue;

if (Operation *tiledAndFusedOp =
fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
}
}
return tileAndFuseResult;
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
%gemm = linalg.matmul {__internal_transform__ = "fusion"}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm : tensor<?x?xf32>
Expand Down Expand Up @@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%generic = linalg.generic {
__internal_linalg_transform__ = "fusion",
__internal_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
Expand Down Expand Up @@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
%init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"}
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm1 : tensor<?x?xf32>
}
Expand Down Expand Up @@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
%transpose = linalg.generic {
__internal_linalg_transform__ = "fusion",
__internal_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
Expand Down Expand Up @@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = linalg.generic {
__internal_linalg_transform__ = "gemm_interchange_fusion",
__internal_transform__ = "gemm_interchange_fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
Expand Down Expand Up @@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
__internal_transform__ = "gemm_plus_gemm_fusion"}
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%5 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
Expand Down Expand Up @@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x
affine_map<(d0, d1) -> (d1, d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
__internal_transform__ = "gemm_plus_gemm_fusion"}
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%5 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
Expand Down Expand Up @@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
%2 = linalg.matmul
{__internal_linalg_transform__ = "gemm_sequence_fusion"}
{__internal_transform__ = "gemm_sequence_fusion"}
ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
return %2 : tensor<?x?xf32>
Expand Down Expand Up @@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
linalg.yield %10, %9 : f32, f32
} -> (tensor<30xf32>, tensor<30x3xf32>)
%6 = linalg.generic {
__internal_linalg_transform__ = "reduction_sequence_fusion",
__internal_transform__ = "reduction_sequence_fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"}
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
}
Expand Down
Loading