Address comments.
MaheshRavishankar committed Oct 20, 2023
1 parent d1f1103 commit 55f9518
Showing 4 changed files with 30 additions and 27 deletions.
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -58,8 +58,8 @@ struct SCFTilingOptions {
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
   SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::to_vector(
-        llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+    mappingVector = llvm::map_to_vector(
+        mapping, [](auto attr) -> Attribute { return attr; });
     return *this;
   }
 };
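Editorial note, not part of the diff: `llvm::map_to_vector` (from llvm/ADT/SmallVectorExtras.h) collapses the `to_vector(map_range(...))` composition used on the removed lines into a single call. A minimal standalone sketch of the equivalence, assuming only LLVM's ADT headers:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

// Old style: build the mapped range, then materialize it into a vector.
llvm::SmallVector<int> doubledOld(llvm::ArrayRef<int> xs) {
  return llvm::to_vector(
      llvm::map_range(xs, [](int x) { return 2 * x; }));
}

// New style: map_to_vector maps and materializes in one step.
llvm::SmallVector<int> doubledNew(llvm::ArrayRef<int> xs) {
  return llvm::map_to_vector(xs, [](int x) { return 2 * x; });
}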
@@ -93,7 +93,7 @@ struct SCFTileAndFuseOptions {
   }
 };
 
-/// Method to tile and op that implements the `TilingInterface` using
+/// Method to tile an op that implements the `TilingInterface` using
 /// `scf.forall`.
 FailureOr<SCFTilingResult>
 tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
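Editorial sketch, not from the patch: putting the declaration above together with the mapping option added in this commit, a caller might tile a `TilingInterface` op into an `scf.forall` distributed over GPU blocks roughly as follows. The helper name `tileToForallOnBlocks` is hypothetical, the third parameter of `tileUsingSCFForallOp` is assumed to be the `SCFTilingOptions` struct shown above, and the `replacements` field name follows the `SCFTilingResult` initialization visible later in this diff.

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

// Hypothetical helper: tile `op` by 10x20 and map the two generated forall
// loops to GPU blocks along y and x, mirroring the "simple_gemm" test pattern
// added in this commit. Assumes the GPU dialect is loaded in the context.
static LogicalResult tileToForallOnBlocks(RewriterBase &rewriter,
                                          TilingInterface op) {
  MLIRContext *context = op->getContext();
  SmallVector<int64_t> tileSizes = {10, 20};

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(getAsIndexOpFoldResult(context, tileSizes));
  tilingOptions.setMapping(
      {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY),
       gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)});

  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCFForallOp(rewriter, op, tilingOptions);
  if (failed(result))
    return failure();
  // Replace the original op with the values produced by the scf.forall
  // (field name assumed from the SCFTilingResult construction in this commit).
  rewriter.replaceOp(op, result->replacements);
  return success();
}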
27 changes: 12 additions & 15 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -767,8 +767,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
 
   // 3. Build the offsets, sizes and steps for the tile and distributed loops.
   SmallVector<OpFoldResult> lbs, ubs, steps;
-  for (auto [index, tileSize, loopRange] :
-       llvm::enumerate(tileSizeVector, loopRanges)) {
+  for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
     if (isConstantIntValue(tileSize, 0))
       continue;
     lbs.push_back(loopRange.offset);
@@ -781,7 +780,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
   if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
     return op->emitOpError("failed to get destination tensors");
 
-  // 5. Build the device mapping attribute;
+  // 5. Build the device mapping attribute.
   std::optional<ArrayAttr> mappingAttr;
   if (!options.mappingVector.empty()) {
     mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
@@ -796,13 +795,10 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
   // 7. Get the tile offset and sizes.
   rewriter.setInsertionPoint(forallOp.getTerminator());
   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  tiledOffsets.reserve(loopRanges.size());
-  tiledSizes.reserve(loopRanges.size());
   ValueRange ivs = forallOp.getInductionVars();
   {
     int materializedLoopNum = 0;
-    for (auto [index, tileSize, loopRange] :
-         llvm::enumerate(tileSizeVector, loopRanges)) {
+    for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
       if (isConstantIntValue(tileSize, 0)) {
         tiledOffsets.push_back(loopRange.offset);
         tiledSizes.push_back(loopRange.size);
@@ -816,15 +812,15 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
   }
 
   // 8. Tile the operation. Clone the operation to allow fix up of destination
-  // operands
+  // operands.
   ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
   Operation *clonedOp =
       cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
   FailureOr<TilingResult> tilingResult =
       cast<TilingInterface>(clonedOp).getTiledImplementation(
           rewriter, tiledOffsets, tiledSizes);
   if (failed(tilingResult))
-    return clonedOp->emitError("Failed to tile op: ");
+    return clonedOp->emitError("failed to tile op: ");
   rewriter.eraseOp(clonedOp);
 
   // 9. Parallel insert back into the result tensor.
@@ -836,24 +832,25 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
     if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
                                         tiledSizes, resultOffsets,
-                                        resultSizes)))
+                                        resultSizes))) {
       return op->emitOpError("output offsets couldn't be calculated");
+    }
 
     SmallVector<OpFoldResult> strides(resultSizes.size(),
                                       rewriter.getIndexAttr(1));
 
-    // 5.b. Parallel insertions are inserted at the end of the combining
+    // 9.b. Parallel insertions are inserted at the end of the combining
     // terminator.
     rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
     rewriter.create<tensor::ParallelInsertSliceOp>(
         loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
   }
 
-  // 10. Return the tiling result;
+  // 10. Return the tiling result.
   return scf::SCFTilingResult{
       tilingResult->tiledOps,
       {forallOp.getOperation()},
-      llvm::to_vector(llvm::map_range(forallOp.getResults(),
-                                      [](auto val) -> Value { return val; }))};
+      llvm::map_to_vector(forallOp.getResults(),
+                          [](auto val) -> Value { return val; })};
 }
 
 //===----------------------------------------------------------------------===//
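Editorial note, not part of the diff: the two loop rewrites above switch from `llvm::enumerate` to `llvm::zip` because the `index` binding was unused. A small standalone sketch of the difference between the two iteration styles, assuming only LLVM's ADT headers:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// enumerate over several ranges yields (index, element, element) tuples;
// useful when the position is actually needed.
int sumWithEnumerate(llvm::ArrayRef<int> sizes, llvm::ArrayRef<int> ranges) {
  int sum = 0;
  for (auto [index, size, range] : llvm::enumerate(sizes, ranges))
    sum += static_cast<int>(index) + size + range;
  return sum;
}

// zip yields only the paired elements, which is all the tiling loops use.
int sumWithZip(llvm::ArrayRef<int> sizes, llvm::ArrayRef<int> ranges) {
  int sum = 0;
  for (auto [size, range] : llvm::zip(sizes, ranges))
    sum += size + range;
  return sum;
}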
@@ -34,6 +34,7 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK: scf.forall.in_parallel {
 // CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
 // CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: mapping = [#gpu.block<y>, #gpu.block<x>]
 // CHECK: return %[[RESULT]]
 
 // -----
23 changes: 14 additions & 9 deletions mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -443,9 +444,9 @@ struct TestTilingInterfacePass
   TestTilingInterfacePass(const TestTilingInterfacePass &pass)
       : PassWrapper(pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
-                    memref::MemRefDialect, scf::SCFDialect,
-                    tensor::TensorDialect>();
+    registry.insert<affine::AffineDialect, gpu::GPUDialect,
+                    linalg::LinalgDialect, memref::MemRefDialect,
+                    scf::SCFDialect, tensor::TensorDialect>();
     linalg::registerTilingInterfaceExternalModels(registry);
     tensor::registerTilingInterfaceExternalModels(registry);
   }
@@ -506,15 +507,16 @@ static void addPatternForTiling(MLIRContext *context,
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
-static void addPatternForTilingUsingForall(MLIRContext *context,
-                                           RewritePatternSet &patterns,
-                                           StringRef filterName,
-                                           ArrayRef<int64_t> tileSizes,
-                                           ArrayRef<int64_t> interchange = {}) {
+static void addPatternForTilingUsingForall(
+    MLIRContext *context, RewritePatternSet &patterns, StringRef filterName,
+    ArrayRef<int64_t> tileSizes,
+    ArrayRef<DeviceMappingAttrInterface> mapping = {},
+    ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
   SmallVector<OpFoldResult> tileSizesOfr =
       getAsIndexOpFoldResult(context, tileSizes);
   tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+  tilingOptions.setMapping(mapping);
   TransformationFilter filter(StringAttr::get(context, filterName),
                               StringAttr::get(context, "tiled"));
   patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
@@ -581,7 +583,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
   }
   if (testTilingForAll) {
     // 1. Tiling M and N dims of `linalg.matmul` on tensors.
-    addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+    addPatternForTilingUsingForall(
+        context, patterns, "simple_gemm", {10, 20},
+        {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY),
+         gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)});
     // 2. Tiling 3D parallel generic op which implements a transpose.
     addPatternForTilingUsingForall(context, patterns,
                                    "parallel_generic_transpose", {10, 0, 20});
