Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
MaheshRavishankar committed Sep 25, 2023
1 parent a9e014d commit c02f458
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
12 changes: 9 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,15 @@ class TileConsumerAndFuseInputProducer final

// Fuse the candidate immeidate operands into the tiled loop.
OpBuilder::InsertionGuard guard(rewriter);
auto forLoops =
llvm::to_vector(llvm::map_range(tilingResult->loops, [](Operation *op) {
return cast<scf::ForOp>(op);
}));
while (!candidates.empty()) {
tensor::ExtractSliceOp sliceOp = candidates.back();
candidates.pop_back();
std::optional<scf::SCFFuseProducerOfSliceResult> result =
tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops);
tileAndFuseProducerOfSlice(rewriter, sliceOp, forLoops);
if (result) {
// Mark the fused input producer for distribution when writing to shared
// memory. We cannot use the current matmul op's tiling scheme here
Expand All @@ -156,6 +160,8 @@ class TileConsumerAndFuseInputProducer final
rewriter, result->tiledAndFusedProducer.getDefiningOp());
}
}
tilingResult->loops = llvm::to_vector(
llvm::map_range(forLoops, [](auto op) -> Operation * { return op; }));
return tilingResult;
}

Expand Down Expand Up @@ -304,10 +310,10 @@ static LogicalResult tileAndUnrollConv(func::FuncOp funcOp) {
// Fully unroll the generated loop. This allows us to remove the loop
// for parallel output window dimension, so it helps future vector
// transformations.
ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops;
ArrayRef<Operation *> loops = tileAndFuseResult.value().loops;
if (!loops.empty()) {
assert(loops.size() == 1);
scf::ForOp loopOp = loops.front();
scf::ForOp loopOp = cast<scf::ForOp>(loops.front());
IntegerAttr ub;
if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) {
loopOp.emitOpError("upper bound should be a constant");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
if (failed(tilingResult)) {
return failure();
}
auto forLoops = llvm::to_vector(llvm::map_range(
tilingResult->loops, [](Operation *op) { return cast<scf::ForOp>(op); }));
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());
// A map from untiled value to scf.for iter_arg. The iter_arg is used for DPS
Expand All @@ -129,9 +131,9 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
tilingResult->tiledOps[0] = replacementTiledOp.value();
}
} else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(rootOp)) {
for (auto [init, iterArg] :
llvm::zip_equal(dpsOp.getDpsInitOperands(),
tilingResult->loops.back().getRegionIterArgs())) {
for (auto [init, iterArg] : llvm::zip_equal(
dpsOp.getDpsInitOperands(),
cast<scf::ForOp>(forLoops.back()).getRegionIterArgs())) {
mapToIterArg[init->get()] = iterArg;
}
}
Expand Down Expand Up @@ -174,20 +176,18 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
tilingResult->loops);
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedProducer)
continue;

// Check if the fused producer has other uses that require the value
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
return !isIgnoredUser(user, tilingResult->loops.front());
return !isIgnoredUser(user, forLoops.front());
})) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedProducer.value(),
tilingResult->loops);
fusedProducer.value(), forLoops);
yieldedValuesToOrigValues.push_back(untiledProducer);
}

Expand All @@ -198,7 +198,7 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
}
}

scf::ForOp outermostLoop = tilingResult->loops.front();
scf::ForOp outermostLoop = forLoops.front();
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ static LogicalResult tileAndDistributeToThreads(linalg::LinalgOp consumerOp,
// We don't distribute here; instead, it will be done in a later step
// after bufferization. So add attributes to the tiled loop nest to
// indicate that they should be distributed to invocations.
ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops;
ArrayRef<Operation *> loops = tileAndFuseResult.value().loops;
const char *attrName = getSPIRVDistributeAttrName();
// We can have more than 3 dimensions being tiled (e.g., for convolutions with
// non-1 batch). But only the innermost 3 dimensions are distributed.
Expand Down Expand Up @@ -273,10 +273,10 @@ static LogicalResult tileAndUnrollConvWindow(func::FuncOp funcOp,
// for parallel output window dimension, so it helps future vector
// transformations.

ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops;
ArrayRef<Operation *> loops = tileAndFuseResult.value().loops;
if (!loops.empty()) {
assert(loops.size() == 1);
scf::ForOp loopOp = loops.front();
scf::ForOp loopOp = cast<scf::ForOp>(loops.front());
IntegerAttr ub;
if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) {
return loopOp.emitOpError("upper bound should be a constant");
Expand Down

0 comments on commit c02f458

Please sign in to comment.