[mlir][TilingInterface] NFC code changes separated out from introduction of `scf::tileUsingSCFForallOp`. (llvm#67081)

This patch contains NFC changes that are a precursor to the introduction of the `scf::tileUsingSCFForallOp` method in llvm#67083.
MaheshRavishankar authored and legrosbuffle committed Sep 29, 2023
1 parent f4179d1 commit 3aeef0e
Showing 8 changed files with 148 additions and 127 deletions.
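The key NFC change across these files is that the `loops` member of `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` is generalized from `SmallVector<scf::ForOp>` to `SmallVector<Operation *>`, so the same result structs can later carry `scf.forall` loops as well. The sketch below shows how a caller might adapt to this; the `tileAndInspect` function and its setup are hypothetical, and this tiling path still produces only `scf.for` loops, so casting back is safe here.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Hypothetical caller: tile `op` and walk the generated loop nest.
static LogicalResult tileAndInspect(RewriterBase &rewriter, TilingInterface op,
                                    const scf::SCFTilingOptions &options) {
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCFForOp(rewriter, op, options);
  if (failed(result))
    return failure();
  // `loops` now holds `Operation *`; cast back to the concrete loop type
  // where one is expected.
  for (Operation *loop : result->loops) {
    auto forOp = cast<scf::ForOp>(loop);
    (void)forOp; // e.g. inspect forOp.getLowerBound(), forOp.getStep(), ...
  }
  return success();
}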
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -60,7 +60,7 @@ struct SCFTilingResult {
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
SmallVector<Operation *> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
@@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
/// generated operation.
llvm::SetVector<Operation *> tiledAndFusedOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
SmallVector<Operation *> loops;
/// The replacement values to use for the tiled and fused operations.
llvm::DenseMap<Value, Value> replacements;
};
14 changes: 5 additions & 9 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
SmallVector<Operation *> opsToReplace{target};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
SmallVector<Value> replacements;
replacements.reserve(toReplace->getNumResults());
for (OpResult res : toReplace->getResults()) {
auto it = tiledResults->replacements.find(res);
if (it == tiledResults->replacements.end())
replacements.push_back(res);
else
replacements.push_back(it->getSecond());
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res))
rewriter.replaceAllUsesWith(res, replacement);
if (toReplace->use_empty()) {
rewriter.eraseOp(toReplace);
}
rewriter.replaceOp(toReplace, replacements);
}

// Report back the relevant handles to the transform op.
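For reference, the new per-result replacement logic in `applyTilingToAll` boils down to the pattern sketched below, a minimal standalone sketch assuming `rewriter`, `toReplace`, and the `replacements` map are supplied by the caller: `DenseMap::lookup` returns a null `Value` when a result has no replacement, so only replaced results have their uses rewritten, and the original op is erased once it is dead.

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch of the replace-then-erase pattern used by the updated code.
static void replacePerResult(RewriterBase &rewriter, Operation *toReplace,
                             const llvm::DenseMap<Value, Value> &replacements) {
  for (OpResult res : toReplace->getResults()) {
    // lookup() returns a default-constructed (null) Value for missing keys.
    if (Value replacement = replacements.lookup(res))
      rewriter.replaceAllUsesWith(res, replacement);
  }
  // Only erase the original op once nothing uses it anymore.
  if (toReplace->use_empty())
    rewriter.eraseOp(toReplace);
}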
112 changes: 70 additions & 42 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
return filledVector;
}

/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
template <typename SrcOpTy>
static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
return llvm::to_vector(
llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
}
template <typename SrcOpTy>
static SmallVector<Operation *>
getAsOperations(const SmallVector<SrcOpTy> &ops) {
return getAsOperations(ArrayRef<SrcOpTy>(ops));
}

/// Convert a list of `Operation *` to a list of `DstOpTy`.
template <typename DstOpTy>
static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
return llvm::to_vector(
llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
}
template <typename DstOpTy>
static SmallVector<DstOpTy>
castToTypedOperations(const SmallVector<Operation *> &ops) {
return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
}

//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
Value tileSize) {
std::optional<int64_t> ts = getConstantIntValue(tileSize);
if (ts && ts.value() == 1)
return getAsOpFoldResult(tileSize);
OpFoldResult tileSize) {
if (isConstantIntValue(tileSize, 1))
return tileSize;

if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -296,8 +319,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}

scf::SCFTilingResult tilingResult;
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> forLoops;
{
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -320,8 +343,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
tilingResult.loops = generateTileLoopNest(
rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
tileSizeVector, offsets, sizes);

if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -331,30 +354,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}

LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
if (!forLoops.empty()) {
llvm::dbgs() << "LoopNest shell :\n";
tilingResult.loops.front().dump();
forLoops.front().dump();
llvm::dbgs() << "\n";
}
});

// 4. Generate the tiled implementation within the inner most loop.
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
if (!forLoops.empty())
rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
tilingResult.tiledOps.append(tiledImplementation->tiledOps);

if (op->getNumResults() == 0) {
// nothing more to do.
return tilingResult;
return scf::SCFTilingResult{
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
}

// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
if (tilingResult.loops.empty()) {
tilingResult.replacements = tiledImplementation->tiledValues;
return tilingResult;
if (forLoops.empty()) {
return scf::SCFTilingResult{tiledImplementation->tiledOps,
getAsOperations(forLoops),
tiledImplementation->tiledValues};
}

// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -378,18 +401,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");

tilingResult.replacements = yieldTiledValues(
SmallVector<Value> replacements = yieldTiledValues(
rewriter, destinationTensors, tiledImplementation.value(),
resultOffsetsList, resultSizesList, tilingResult.loops);

resultOffsetsList, resultSizesList, forLoops);
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
tilingResult.loops.front().dump();
forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
return tilingResult;
return scf::SCFTilingResult{tiledImplementation->tiledOps,
getAsOperations(forLoops), replacements};
}

FailureOr<scf::SCFReductionTilingResult>
@@ -467,6 +490,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
results.mergeOp = mergeOp;
return results;
}

//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -637,28 +661,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}

// 1. First tile the consumer.
scf::SCFTileAndFuseResult tileAndFuseResult;
SmallVector<scf::ForOp> forLoops;
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
DenseMap<Value, Value> replacements;
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
for (const auto &result : llvm::enumerate(
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
tileAndFuseResult.replacements[std::get<0>(result.value())] =
std::get<1>(result.value());
tiledAndFusedOps.insert(tiledOp);
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
for (auto [index, origValue, replacement] :
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
replacements[origValue] = replacement;
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
result.index())] = result.index();
index)] = index;
}
}

// If there are no loops generated, fusion is immaterial.
if (tileAndFuseResult.loops.empty())
return tileAndFuseResult;
if (forLoops.empty()) {
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}

// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
@@ -675,7 +702,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
addCandidateSlices(tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
@@ -685,19 +712,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
tileAndFuseResult.loops);
if (!fusedProducer)
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedResult)
continue;

if (Operation *tiledAndFusedOp =
fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
}
}
return tileAndFuseResult;
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}

//===----------------------------------------------------------------------===//
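Two idioms are worth noting in the rewritten `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`: the result struct is now built from locally accumulated state (`fusedProducers`, `tiledAndFusedOps`, `replacements`) rather than mutated in place, and the consumer's results are paired with their replacements via the multi-range form of `llvm::enumerate`. The snippet below is a small standalone sketch of that `enumerate` idiom using placeholder data, not the real results and replacements.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  llvm::SmallVector<int> results = {10, 20, 30};
  llvm::SmallVector<int> replacements = {11, 21, 31};
  // Each element unpacks as (index, element of first range, element of
  // second range); the ranges must have the same length.
  for (auto [index, orig, repl] : llvm::enumerate(results, replacements))
    std::printf("result #%zu: %d -> %d\n", index, orig, repl);
  return 0;
}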
@@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
%gemm = linalg.matmul {__internal_transform__ = "fusion"}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm : tensor<?x?xf32>
@@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%generic = linalg.generic {
__internal_linalg_transform__ = "fusion",
__internal_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
@@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
%init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"}
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm1 : tensor<?x?xf32>
}
@@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
%transpose = linalg.generic {
__internal_linalg_transform__ = "fusion",
__internal_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
@@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = linalg.generic {
__internal_linalg_transform__ = "gemm_interchange_fusion",
__internal_transform__ = "gemm_interchange_fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
@@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
__internal_transform__ = "gemm_plus_gemm_fusion"}
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%5 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x
affine_map<(d0, d1) -> (d1, d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
__internal_transform__ = "gemm_plus_gemm_fusion"}
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%5 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
%2 = linalg.matmul
{__internal_linalg_transform__ = "gemm_sequence_fusion"}
{__internal_transform__ = "gemm_sequence_fusion"}
ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
return %2 : tensor<?x?xf32>
@@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
linalg.yield %10, %9 : f32, f32
} -> (tensor<30xf32>, tensor<30x3xf32>)
%6 = linalg.generic {
__internal_linalg_transform__ = "reduction_sequence_fusion",
__internal_transform__ = "reduction_sequence_fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
@@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"}
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
}