Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate scf::forall. #67083

Merged
merged 3 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ struct SCFTilingOptions {
interchangeVector = llvm::to_vector(interchange);
return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that don't support such a mapping (like
/// `scf.for`).
SmallVector<Attribute> mappingVector = {};
/// Replaces `mappingVector` with the attributes in `mapping`, each stored as
/// a plain `Attribute`. Returns `*this` so option setters can be chained.
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
  SmallVector<Attribute> converted;
  converted.reserve(mapping.size());
  for (DeviceMappingAttrInterface deviceAttr : mapping) {
    // Drop the interface wrapper; only the underlying attribute is stored.
    Attribute asAttribute = deviceAttr;
    converted.push_back(asAttribute);
  }
  mappingVector = std::move(converted);
  return *this;
}
};

/// Transformation information returned after tiling.
Expand Down Expand Up @@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
}
};

/// Method to tile an op that implements the `TilingInterface` using
/// `scf.forall`.
///
/// - `rewriter` is used to create the loop and the tiled operation.
/// - `op` is the operation to tile. Tiling fails if its iteration domain is
///   empty or if any loop in the domain has a stride other than 1.
/// - `options` supplies the tile sizes (through
///   `tileSizeComputationFunction`) and, optionally, the device mapping
///   (`mappingVector`) attached to the generated loop. Loops whose tile size
///   is the constant 0 are not tiled.
///
/// On success returns the tiled operations, the generated `scf.forall` and
/// the results of that loop.
FailureOr<SCFTilingResult>
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
const SCFTilingOptions &options);

/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
Expand Down
129 changes: 127 additions & 2 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
Value tileSize) {
OpFoldResult tileSize) {
std::optional<int64_t> ts = getConstantIntValue(tileSize);
if (ts && ts.value() == 1)
return getAsOpFoldResult(tileSize);
return tileSize;

if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
Expand All @@ -122,6 +122,19 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}

/// Returns a clone of `op` created through `rewriter`. If the clone
/// implements `DestinationStyleOpInterface`, its DPS init operands are
/// reassigned to `newDestArgs`; otherwise the clone is returned unchanged.
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange newDestArgs) {
  Operation *copy = rewriter.clone(*op);
  auto dpsCopy = dyn_cast<DestinationStyleOpInterface>(copy);
  // Ops that are not destination-style have no inits to rewire.
  if (!dpsCopy)
    return copy;
  dpsCopy.getDpsInitsMutable().assign(newDestArgs);
  return copy;
}

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
Expand Down Expand Up @@ -728,6 +741,118 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
getAsOperations(forLoops), replacements};
}

//===----------------------------------------------------------------------===//
// tileUsingSCFForallOp implementation.
//===----------------------------------------------------------------------===//

/// Tiles `op` into a single `scf.forall` loop.
///
/// - `rewriter` creates all new IR; its insertion point is restored on exit.
/// - `op` is the operation to tile. This function neither erases nor
///   replaces `op`.
/// - `options.tileSizeComputationFunction` provides one tile size per loop of
///   the iteration domain; missing trailing entries are treated as 0. A
///   constant tile size of 0 leaves that loop untiled.
/// - `options.mappingVector`, when non-empty, becomes the mapping attribute of
///   the generated loop.
///
/// Returns the tiled operations, the generated `scf.forall`, and the results
/// of that loop. Emits an error and returns failure when:
/// - no tile size computation function is set,
/// - the iteration domain is empty or has a non-unit stride,
/// - destination tensors cannot be obtained,
/// - the tiled implementation cannot be generated, or
/// - the result tile position cannot be computed.
///
/// NOTE: failures reported after the loop has been created (steps 8 and 9)
/// return without removing the IR built so far.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
                                const scf::SCFTilingOptions &options) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(rewriter);

  // 0. The tile size callback is invoked unconditionally below; reject an
  // unset callback up front instead of calling through an empty callable.
  if (!options.tileSizeComputationFunction)
    return op->emitOpError("missing tile size computation function");

  // 1. Get the range of loops that are represented by the operation.
  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
  if (loopRanges.empty())
    return op->emitOpError("expected non-empty loop ranges");
  auto hasNonUnitStride = [](Range r) {
    return !isConstantIntValue(r.stride, 1);
  };
  if (llvm::any_of(loopRanges, hasNonUnitStride))
    return op->emitOpError("only stride-1 supported atm");

  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
  SmallVector<OpFoldResult> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));

  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
  // Only loops with a non-zero tile size are materialized in the forall.
  SmallVector<OpFoldResult> lbs, ubs, steps;
  for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
    if (isConstantIntValue(tileSize, 0))
      continue;
    lbs.push_back(loopRange.offset);
    ubs.push_back(loopRange.size);
    steps.push_back(tileSize);
  }

  // 4. Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
    return op->emitOpError("failed to get destination tensors");

  // 5. Build the device mapping attribute. It stays unset when no mapping was
  // requested.
  std::optional<ArrayAttr> mappingAttr;
  if (!options.mappingVector.empty()) {
    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
  }

  // 6. Create the ForallOp. We don't use the lambda body-builder
  // version because we require the use of RewriterBase in the body, so we
  // manually move the insertion point to the body below.
  auto forallOp =
      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);

  // 7. Get the tile offset and sizes. Untiled loops keep their full range;
  // tiled loops start at the induction variable and are clamped to the
  // remaining extent.
  rewriter.setInsertionPoint(forallOp.getTerminator());
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  ValueRange ivs = forallOp.getInductionVars();
  {
    int materializedLoopNum = 0;
    for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
      if (isConstantIntValue(tileSize, 0)) {
        tiledOffsets.push_back(loopRange.offset);
        tiledSizes.push_back(loopRange.size);
        continue;
      }
      Value iv = ivs[materializedLoopNum++];
      tiledOffsets.push_back(iv);
      tiledSizes.push_back(
          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
    }
  }

  // 8. Tile the operation. Clone the operation to allow fix up of destination
  // operands.
  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
  Operation *clonedOp =
      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
  FailureOr<TilingResult> tilingResult =
      cast<TilingInterface>(clonedOp).getTiledImplementation(
          rewriter, tiledOffsets, tiledSizes);
  if (failed(tilingResult))
    return clonedOp->emitError("failed to tile op: ");
  // The clone only served as the source for the tiled implementation.
  rewriter.eraseOp(clonedOp);

  // 9. Parallel insert back into the result tensor.
  for (auto [index, tiledValue, destBBArg] :
       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
    // 9.a. Partial subset information is inserted just before the terminator.
    rewriter.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
                                        tiledSizes, resultOffsets,
                                        resultSizes))) {
      return op->emitOpError("output offsets couldn't be calculated");
    }

    SmallVector<OpFoldResult> strides(resultSizes.size(),
                                      rewriter.getIndexAttr(1));
    // 9.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    rewriter.create<tensor::ParallelInsertSliceOp>(
        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
  }

  // 10. Return the tiling result.
  return scf::SCFTilingResult{
      tilingResult->tiledOps,
      {forallOp.getOperation()},
      llvm::map_to_vector(forallOp.getResults(),
                          [](auto val) -> Value { return val; })};
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand Down
167 changes: 167 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s

// Matmul on dynamically shaped tensors. The expectations that follow require
// the M and N loops to be tiled with steps 10 and 20 into one `scf.forall`
// carrying a `#gpu.block` mapping, while the K dimension stays untiled.
// NOTE(review): the `__internal_transform__` marker presumably selects the
// tiling configuration inside the test pass - confirm against the pass.
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
// CHECK: func.func @simple_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
// CHECK: mapping = [#gpu.block<y>, #gpu.block<x>]
// CHECK: return %[[RESULT]]

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
// Generic op that writes the same input element into two differently
// permuted outputs. The expectations that follow require a two-result
// `scf.forall` with steps (10, 20) over the loops of extent 128 and 300,
// the loop of extent 200 left untiled, and one
// `tensor.parallel_insert_slice` per result.
func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
%init0 = tensor.empty() : tensor<128x300x200xf32>
%init1 = tensor.empty() : tensor<300x128x200xf32>
%0:2 = linalg.generic {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel"]}
{__internal_transform__ = "parallel_generic_transpose"}
ins(%arg0 : tensor<128x200x300xf32>)
outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
linalg.yield %b0, %b0 : f32, f32
} -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (128, 300) step (10, 20)
// CHECK-SAME: shared_outs(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
// CHECK: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG2]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#0 into %[[ARG1]][%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#1 into %[[ARG2]][%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
// CHECK: }
// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1

// -----

// Strided 2-D convolution on dynamically shaped tensors. The expectations
// that follow require a three-loop `scf.forall` with steps (10, 20, 30)
// bounded by filter dims 0 and 1 and input dim 3, with the init tensor
// sliced at offset 0 over its full extent.
func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%0 = linalg.conv_2d_nhwc_hwcf {
strides = dense<[2, 3]> : tensor<2xi64>,
dilation = dense<[4, 5]> : tensor<2xi64>,
__internal_transform__ = "simple_conv"}
ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
// CHECK-LABEL: func.func @conv2D(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]], %[[IV2:[a-zA-Z0-9]+]]) =
// CHECK-SAME: (0, 0, 0) to (%[[P]], %[[Q]], %[[C]]) step (10, 20, 30) shared_outs(%[[INIT0:.+]] = %[[INIT]])
// CHECK-DAG: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
// CHECK-DAG: %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]]
// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]]
// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]]
// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]]
// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[CONV_TILE]] into %[[INIT0]]
// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] [1, 1, 1, 1]
// CHECK: return %[[RESULT]]

// -----

// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>

// Generic op whose body reads both `linalg.index` values. The expectations
// that follow require each index inside the tiled body to be offset by the
// matching `scf.forall` induction variable before it is used.
func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// Check that we correctly amend "linalg.index" results.

%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
{__internal_transform__ = "indexed_semantics"}
ins(%arg0: tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%1 = linalg.index 0 : index
%2 = linalg.index 1 : index
%3 = arith.addi %1, %2 : index
%4 = arith.index_cast %3 : index to i64
%5 = arith.uitofp %4 : i64 to f32
%6 = arith.addf %5, %arg2 : f32
linalg.yield %6 : f32
} -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: @indexed_semantics
// CHECK: scf.forall (%[[I0:.+]], %[[I1:.+]]) =
// CHECK: %[[INDEX0:.+]] = linalg.index 0
// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
// CHECK: %[[INDEX1:.+]] = linalg.index 1
// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
Loading