Skip to content

Commit

Permalink
[mlir][scf] Implement getSingle... of LoopLikeOpinterface for scf::Pa…
Browse files Browse the repository at this point in the history
…rallelOp (llvm#68511)

This adds implementations for `getSingleIterationVar`,
`getSingleLowerBound`, `getSingleUpperBound`, `getSingleStep` of
`LoopLikeOpInterface` to `scf::ParallelOp`. Until now, the
implementations for these methods defaulted to returning `std::nullopt`,
even in the special case where the parallel Op only has one dimension.

Related: llvm#67883
  • Loading branch information
ubfx authored Oct 20, 2023
1 parent b8ad68f commit aa0208d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
def ParallelOp : SCF_Op<"parallel",
[AutomaticAllocationScope,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"scf::YieldOp">]> {
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2936,6 +2936,30 @@ void ParallelOp::print(OpAsmPrinter &p) {

SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }

std::optional<Value> ParallelOp::getSingleInductionVar() {
if (getNumLoops() != 1)
return std::nullopt;
return getBody()->getArgument(0);
}

std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
if (getNumLoops() != 1)
return std::nullopt;
return getLowerBound()[0];
}

std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
if (getNumLoops() != 1)
return std::nullopt;
return getUpperBound()[0];
}

std::optional<OpFoldResult> ParallelOp::getSingleStep() {
if (getNumLoops() != 1)
return std::nullopt;
return getStep()[0];
}

ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ target_link_libraries(MLIRDialectTests
add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Transform)
Expand Down
8 changes: 8 additions & 0 deletions mlir/unittests/Dialect/SCF/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_mlir_unittest(MLIRSCFTests
LoopLikeSCFOpsTest.cpp
)
target_link_libraries(MLIRSCFTests
PRIVATE
MLIRIR
MLIRSCFDialect
)
89 changes: 89 additions & 0 deletions mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::scf;

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

class SCFLoopLikeTest : public ::testing::Test {
protected:
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
}

void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
EXPECT_TRUE(maybeLb.has_value());
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
EXPECT_TRUE(maybeUb.has_value());
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
EXPECT_TRUE(maybeStep.has_value());
std::optional<OpFoldResult> maybeIndVar =
loopLikeOp.getSingleInductionVar();
EXPECT_TRUE(maybeIndVar.has_value());
}

void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
EXPECT_FALSE(maybeLb.has_value());
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
EXPECT_FALSE(maybeUb.has_value());
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
EXPECT_FALSE(maybeStep.has_value());
std::optional<OpFoldResult> maybeIndVar =
loopLikeOp.getSingleInductionVar();
EXPECT_FALSE(maybeIndVar.has_value());
}

MLIRContext context;
OpBuilder b;
Location loc;
};

TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
Value step = b.create<arith::ConstantIndexOp>(loc, 2);

auto forOp = b.create<scf::ForOp>(loc, lb, ub, step);
checkUnidimensional(forOp);

auto forallOp = b.create<scf::ForallOp>(
loc, ArrayRef<OpFoldResult>(lb), ArrayRef<OpFoldResult>(ub),
ArrayRef<OpFoldResult>(step), ValueRange(), std::nullopt);
checkUnidimensional(forallOp);

auto parallelOp = b.create<scf::ParallelOp>(
loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange());
checkUnidimensional(parallelOp);
}

TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
Value step = b.create<arith::ConstantIndexOp>(loc, 2);

auto forallOp = b.create<scf::ForallOp>(
loc, ArrayRef<OpFoldResult>({lb, lb}), ArrayRef<OpFoldResult>({ub, ub}),
ArrayRef<OpFoldResult>({step, step}), ValueRange(), std::nullopt);
checkMultidimensional(forallOp);

auto parallelOp =
b.create<scf::ParallelOp>(loc, ValueRange({lb, lb}), ValueRange({ub, ub}),
ValueRange({step, step}), ValueRange());
checkMultidimensional(parallelOp);
}

0 comments on commit aa0208d

Please sign in to comment.