-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate scf::forall
.
#67083
[mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate scf::forall
.
#67083
Conversation
Only the second commit is really part of this PR. The first commit is part of #67081 |
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-scf ChangesSimilar to Patch is 38.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67083.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index ca641c596c7b7bb..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
interchangeVector = llvm::to_vector(interchange);
return *this;
}
+
+ /// Specify mapping of loops to devices. This is only respected when the loop
+ /// constructs support such a mapping (like `scf.forall`). Will be ignored
+ /// when using loop constructs that dont support such a mapping (like
+ /// `scf.for`)
+ SmallVector<Attribute> mappingVector = {};
+ SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+ mappingVector = llvm::to_vector(
+ llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+ return *this;
+ }
};
/// Transformation information returned after tiling.
@@ -60,7 +71,7 @@ struct SCFTilingResult {
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
- SmallVector<scf::ForOp> loops;
+ SmallVector<Operation *> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
}
};
+/// Method to tile and op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const SCFTilingOptions &options);
+
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
@@ -160,7 +177,7 @@ struct SCFTileAndFuseResult {
/// generated operation.
llvm::SetVector<Operation *> tiledAndFusedOps;
/// The `scf.for` operations that iterate over the tiles.
- SmallVector<scf::ForOp> loops;
+ SmallVector<Operation *> loops;
/// The replacement values to use for the tiled and fused operations.
llvm::DenseMap<Value, Value> replacements;
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1819ca614a060fd..ca3db7401e38caa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
SmallVector<Operation *> opsToReplace{target};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
- SmallVector<Value> replacements;
- replacements.reserve(toReplace->getNumResults());
- for (OpResult res : toReplace->getResults()) {
- auto it = tiledResults->replacements.find(res);
- if (it == tiledResults->replacements.end())
- replacements.push_back(res);
- else
- replacements.push_back(it->getSecond());
+ for (OpResult res : toReplace->getResults())
+ if (auto replacement = tiledResults->replacements.lookup(res))
+ rewriter.replaceAllUsesWith(res, replacement);
+ if (toReplace->use_empty()) {
+ rewriter.eraseOp(toReplace);
}
- rewriter.replaceOp(toReplace, replacements);
}
// Report back the relevant handles to the transform op.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6cfba3fef15ebda..9054f7bcdde7e15 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
return filledVector;
}
+/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
+template <typename SrcOpTy>
+static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
+ return llvm::to_vector(
+ llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
+}
+template <typename SrcOpTy>
+static SmallVector<Operation *>
+getAsOperations(const SmallVector<SrcOpTy> &ops) {
+ return getAsOperations(ArrayRef<SrcOpTy>(ops));
+}
+
+/// Convert a list of `Operation *` to a list of `DstOpTy`
+template <typename DstOpTy>
+static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
+ return llvm::to_vector(
+ llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
+}
+template <typename DstOpTy>
+static SmallVector<DstOpTy>
+castToTypedOperations(const SmallVector<Operation *> &ops) {
+ return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
+}
+
//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
- Value tileSize) {
- std::optional<int64_t> ts = getConstantIntValue(tileSize);
- if (ts && ts.value() == 1)
- return getAsOpFoldResult(tileSize);
+ OpFoldResult tileSize) {
+ if (isConstantIntValue(tileSize, 1))
+ return tileSize;
if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -98,6 +121,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}
+/// Clones the operation and updates the destination if the operation
+/// implements the `DestinationStyleOpInterface`.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+ Operation *op,
+ ValueRange newDestArgs) {
+ Operation *clonedOp = rewriter.clone(*op);
+ if (auto destinationStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+ // Note that this is assuming that
+ auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+ assert((end - start == newDestArgs.size()) &&
+ "expected as many new destination args as number of inits of the "
+ "operation");
+ clonedOp->setOperands(start, end - start, newDestArgs);
+ }
+ return clonedOp;
+}
+
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -295,8 +336,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}
- scf::SCFTilingResult tilingResult;
SmallVector<OpFoldResult> offsets, sizes;
+ SmallVector<scf::ForOp> forLoops;
{
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -319,8 +360,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
- tilingResult.loops = generateTileLoopNest(
- rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+ forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
+ tileSizeVector, offsets, sizes);
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -330,30 +371,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
+ if (!forLoops.empty()) {
llvm::dbgs() << "LoopNest shell :\n";
- tilingResult.loops.front().dump();
+ forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
// 4. Generate the tiled implementation within the inner most loop.
- if (!tilingResult.loops.empty())
- rewriter.setInsertionPoint(
- tilingResult.loops.back().getBody()->getTerminator());
+ if (!forLoops.empty())
+ rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
- tilingResult.tiledOps.append(tiledImplementation->tiledOps);
+
if (op->getNumResults() == 0) {
- // nothing more to do.
- return tilingResult;
+ return scf::SCFTilingResult{
+ tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
}
// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
- if (tilingResult.loops.empty()) {
- tilingResult.replacements = tiledImplementation->tiledValues;
- return tilingResult;
+ if (forLoops.empty()) {
+ return scf::SCFTilingResult{tiledImplementation->tiledOps,
+ getAsOperations(forLoops),
+ tiledImplementation->tiledValues};
}
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -377,18 +418,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");
- tilingResult.replacements = yieldTiledValues(
+ SmallVector<Value> replacements = yieldTiledValues(
rewriter, destinationTensors, tiledImplementation.value(),
- resultOffsetsList, resultSizesList, tilingResult.loops);
-
+ resultOffsetsList, resultSizesList, forLoops);
LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
+ if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
- tilingResult.loops.front().dump();
+ forLoops.front().dump();
llvm::dbgs() << "\n";
}
});
- return tilingResult;
+ return scf::SCFTilingResult{tiledImplementation->tiledOps,
+ getAsOperations(forLoops), replacements};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -466,6 +507,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
results.mergeOp = mergeOp;
return results;
}
+
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -636,7 +678,9 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}
// 1. First tile the consumer.
- scf::SCFTileAndFuseResult tileAndFuseResult;
+ SmallVector<scf::ForOp> forLoops;
+ SetVector<Operation *> fusedProducers, tiledAndFusedOps;
+ DenseMap<Value, Value> replacements;
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
@@ -644,20 +688,21 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
- tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
- tileAndFuseResult.loops = std::move(tilingResult->loops);
- for (const auto &result : llvm::enumerate(
- llvm::zip(consumer->getResults(), tilingResult->replacements))) {
- tileAndFuseResult.replacements[std::get<0>(result.value())] =
- std::get<1>(result.value());
+ tiledAndFusedOps.insert(tiledOp);
+ forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
+ for (auto [index, origValue, replacement] :
+ llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
+ replacements[origValue] = replacement;
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
- result.index())] = result.index();
+ index)] = index;
}
}
// If there are no loops generated, fusion is immaterial.
- if (tileAndFuseResult.loops.empty())
- return tileAndFuseResult;
+ if (forLoops.empty()) {
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+ getAsOperations(forLoops), replacements};
+ }
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
@@ -674,7 +719,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
};
std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
@@ -684,19 +729,135 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
- tileAndFuseResult.loops);
- if (!fusedProducer)
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
+ tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+ if (!fusedResult)
continue;
if (Operation *tiledAndFusedOp =
- fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
- tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+ fusedResult->tiledAndFusedProducer.getDefiningOp()) {
+ fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
+ tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
}
}
- return tileAndFuseResult;
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+ getAsOperations(forLoops), replacements};
+}
+
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForAllOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const scf::SCFTilingOptions &options) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 1. Get the range of loops that are represented by the operation.
+ SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+ if (loopRanges.empty())
+ return op->emitOpError("expected non-empty loop ranges");
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+ if (llvm::any_of(loopRanges, hasStrideOne))
+ return op->emitOpError("only stride-1 supported atm");
+
+ // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+ // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+ SmallVector<OpFoldResult> tileSizeVector =
+ options.tileSizeComputationFunction(rewriter, op);
+ tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+ // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+
+ // 4. Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+ return op->emitOpError("failed to get destination tensors");
+
+ // 5. Build the device mapping attribute;
+ std::optional<ArrayAttr> mappingAttr;
+ if (!options.mappingVector.empty()) {
+ mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+ }
+
+ // 6. Create the ForallOp. We don't use the lambda body-builder
+ // version because we require the use of RewriterBase in the body, so we
+ // manually move the insertion point to the body below.
+ auto forallOp =
+ rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+ // 7. Get the tile offset and sizes.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(loopRanges.size());
+ tiledSizes.reserve(loopRanges.size());
+ ValueRange ivs = forallOp.getInductionVars();
+ {
+ int materializedLoopNum = 0;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ tiledOffsets.push_back(loopRange.offset);
+ tiledSizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ tiledOffsets.push_back(iv);
+ tiledSizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
+ }
+
+ // 8. Tile the operation. Clone the operation to allow fix up of destination
+ // operands
+ ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+ Operation *clonedOp =
+ cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+ FailureOr<TilingResult> tilingResult =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(
+ rewriter, tiledOffsets, tiledSizes);
+ if (failed(tilingResult))
+ return clonedOp->emitError("Failed to tile op: ");
+ rewriter.eraseOp(clonedOp);
+
+ // 9. Parallel insert back into the result tensor.
+ for (auto [index, tiledValue, destBBArg] :
+ llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+ // 9.a. Partial subset information is inserted just before the terminator.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+ tiledSizes, resultOffsets,
+ resultSizes)))
+ return op->emitOpError("output offsets couldn't be calculated");
+ SmallVector<OpFoldResult> strides(resultSizes.size(),
+ rewriter.getIndexAttr(1));
+
+ // 5.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
+ rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+ }
+
+ // 10. Return the tiling result;
+ return scf::SCFTilingResult{
+ tilingResult->tiledOps,
+ {forallOp.getOperation()},
+ llvm::to_vector(llvm::map_range(forallOp.getResults(),
+ [](auto val) -> Value { return val; }))};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..f40374b7b5485da
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?...
[truncated]
|
@nicolasvasilache I adapted the implementation here from
|
Great to see this! See a rough sketch here of stuff I'd like to drop https://github.com/nicolasvasilache/llvm-project/tree/tiling-cleanups, do you think this is doable or is there something load-bearing here ?
They have different behaviors in cases that don't divide: num_threads guarantees a static number of threads and makes the tile sizes dynamic. If we only specified tile sizes and things don't divide we could end up with dynamic number of threads which has quite some implications later. |
Yes, I was happy to see this getting some attention. Will be good to put to bed. |
Thanks! yeah those should be retire-able in due course. Ill keep these in mind.
I am trying to basically see if I can standardize on the tile_size variant.. That is consistent with the tiling using scf.for... Maybe some more context would help. For example, |
…ion of `scf::tileUsingSCFForallop`. (llvm#67081) This patch contains NFC changes that are precursor to the introduction of `scf::tileUsingSCFForallOp` method introduced in llvm#67083.
1f2b48c
to
b26e643
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nicolasvasilache this is ready to land. Ill add more things to this like support for interchange, etc. and do more cleanup of the other tiling methods, but this should be ready to go.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some nits!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Looks good to me too.
…e using the interface to generate `scf::forall`. Similar to `scf::tileUsingSCFForOp` that is a method that tiles operations that implement the `TilingInterface`, using `scf.for` operations, this method introduces tiling of operations using `scf.forall`. Most of this implementation is derived from `linalg::tileToForallOp` method. Eventually that method will either be deprecated or moved to use the method introduced here.
b26e643
to
55f9518
Compare
I am confused, why isn't any transform op modified in this commit ? |
For now it is duplicate. I have to build these up incrementally to connect everything and deprecate the old Linalg tiling path. |
Ok .. what is your expected timeline on this ? |
I am working on it... I dont have a timeline though... trying to make progress on this as and when I get time. |
The current implementation of tiling using `scf.for` is convoluted to make sure that the destination passing style of the untiled program is preserved. The addition of support to tile using `scf.forall` (adapted from the transform operation in Linalg) in #67083 used cloning of the tiled operations to better streamline the implementation. This PR adapts the other tiling methods to use a similar approach, making the transformations (and handling destination passing style semantics) more systematic. --------- Co-authored-by: Abhishek-Varma <[email protected]>
Similar to
scf::tileUsingSCFForOp
that is a method that tilesoperations that implement the
TilingInterface
, usingscf.for
operations, this method introduces tiling of operations using
scf.forall
. Most of this implementation is derived fromlinalg::tileToForallOp
method. Eventually that method will either bedeprecated or moved to use the method introduced here.