From c02f458d8881bcfdfaa6fcd711fe342d4cea2063 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 21 Sep 2023 18:40:10 -0700 Subject: [PATCH] Fixes for https://github.com/llvm/llvm-project/pull/67081 --- .../Codegen/Common/GPU/GPUTensorTile.cpp | 12 +++++++++--- .../Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp | 18 +++++++++--------- .../iree/compiler/Codegen/SPIRV/SPIRVTile.cpp | 6 +++--- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp index c2664ff065441..d6b4003817477 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp @@ -140,11 +140,15 @@ class TileConsumerAndFuseInputProducer final // Fuse the candidate immeidate operands into the tiled loop. OpBuilder::InsertionGuard guard(rewriter); + auto forLoops = + llvm::to_vector(llvm::map_range(tilingResult->loops, [](Operation *op) { + return cast<scf::ForOp>(op); + })); while (!candidates.empty()) { tensor::ExtractSliceOp sliceOp = candidates.back(); candidates.pop_back(); std::optional<scf::SCFFuseProducerOfSliceResult> result = - tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops); + tileAndFuseProducerOfSlice(rewriter, sliceOp, forLoops); if (result) { // Mark the fused input producer for distribution when writing to shared // memory. We cannot use the current matmul op's tiling scheme here @@ -156,6 +160,8 @@ class TileConsumerAndFuseInputProducer final rewriter, result->tiledAndFusedProducer.getDefiningOp()); } } + tilingResult->loops = llvm::to_vector( + llvm::map_range(forLoops, [](auto op) -> Operation * { return op; })); return tilingResult; } @@ -304,10 +310,10 @@ static LogicalResult tileAndUnrollConv(func::FuncOp funcOp) { // Fully unroll the generated loop. This allows us to remove the loop // for parallel output window dimension, so it helps future vector // transformations. 
- ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops; + ArrayRef<Operation *> loops = tileAndFuseResult.value().loops; if (!loops.empty()) { assert(loops.size() == 1); - scf::ForOp loopOp = loops.front(); + scf::ForOp loopOp = cast<scf::ForOp>(loops.front()); IntegerAttr ub; if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) { loopOp.emitOpError("upper bound should be a constant"); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp index 31f3faf45b196..0f9991322421a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp @@ -113,6 +113,8 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, if (failed(tilingResult)) { return failure(); } + auto forLoops = llvm::to_vector(llvm::map_range( + tilingResult->loops, [](Operation *op) { return cast<scf::ForOp>(op); })); yieldedValuesToOrigValues.append(rootOp->result_begin(), rootOp->result_end()); // A map from untiled value to scf.for iter_arg. The iter_arg is used for DPS @@ -129,9 +131,9 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, tilingResult->tiledOps[0] = replacementTiledOp.value(); } } else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(rootOp)) { - for (auto [init, iterArg] : - llvm::zip_equal(dpsOp.getDpsInitOperands(), - tilingResult->loops.back().getRegionIterArgs())) { + for (auto [init, iterArg] : llvm::zip_equal( + dpsOp.getDpsInitOperands(), + cast<scf::ForOp>(forLoops.back()).getRegionIterArgs())) { mapToIterArg[init->get()] = iterArg; } } @@ -174,8 +176,7 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, // Materialize the slice of the producer in place. 
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = - tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, - tilingResult->loops); + tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); if (!fusedProducer) continue; @@ -183,11 +184,10 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, // to be yielded from within the tiled loop. OpResult untiledProducer = fusedProducer->origProducer; if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) { - return !isIgnoredUser(user, tilingResult->loops.front()); + return !isIgnoredUser(user, forLoops.front()); })) { yieldReplacementForFusedProducer(rewriter, candidateSliceOp, - fusedProducer.value(), - tilingResult->loops); + fusedProducer.value(), forLoops); yieldedValuesToOrigValues.push_back(untiledProducer); } @@ -198,7 +198,7 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, } } - scf::ForOp outermostLoop = tilingResult->loops.front(); + scf::ForOp outermostLoop = forLoops.front(); for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) { Value replacement = outermostLoop.getResult(index); rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp index 13a7290f7b028..0f40efe8e40fb 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp @@ -138,7 +138,7 @@ static LogicalResult tileAndDistributeToThreads(linalg::LinalgOp consumerOp, // We don't distribute here; instead, it will be done in a later step // after bufferization. So add attributes to the tiled loop nest to // indicate that they should be distributed to invocations. 
- ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops; + ArrayRef<Operation *> loops = tileAndFuseResult.value().loops; const char *attrName = getSPIRVDistributeAttrName(); // We can have more than 3 dimensions being tiled (e.g., for convolutions with // non-1 batch). But only the innermost 3 dimensions are distributed. @@ -273,10 +273,10 @@ static LogicalResult tileAndUnrollConvWindow(func::FuncOp funcOp, // for parallel output window dimension, so it helps future vector // transformations. - ArrayRef<scf::ForOp> loops = tileAndFuseResult.value().loops; + ArrayRef<Operation *> loops = tileAndFuseResult.value().loops; if (!loops.empty()) { assert(loops.size() == 1); - scf::ForOp loopOp = loops.front(); + scf::ForOp loopOp = cast<scf::ForOp>(loops.front()); IntegerAttr ub; if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) { return loopOp.emitOpError("upper bound should be a constant");