-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OpenMPIRBuilder] Added if
clause for teams
#69139
Conversation
This patch adds support for the `if` clause on `teams` construct. The value of the argument must be an integer value. If the value evaluates to true (non-zero) integer, then the number of threads is determined by `num_threads` clause (or default and ICV if `num_threads` is absent). When the condition evaluates to false (zero), then the bounds are set to 1. This essentially means that ``` upperbound = ifexpr ? upperbound : 1 lowerbound = ifexpr ? lowerbound : 1 ```
@llvm/pr-subscribers-flang-openmp Author: Shraiysh (shraiysh) ChangesThis patch adds support for the This essentially means that
Full diff: https://github.com/llvm/llvm-project/pull/69139.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9d2adf229b78654..00b4707a7f820d7 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1923,11 +1923,12 @@ class OpenMPIRBuilder {
/// \param NumTeamsUpper Upper bound on the number of teams.
/// \param ThreadLimit on the number of threads that may participate in a
/// contention group created by each team.
- InsertPointTy createTeams(const LocationDescription &Loc,
- BodyGenCallbackTy BodyGenCB,
- Value *NumTeamsLower = nullptr,
- Value *NumTeamsUpper = nullptr,
- Value *ThreadLimit = nullptr);
+ /// \param IfExpr is the integer argument value of the if condition on the
+ /// teams clause.
+ InsertPointTy
+ createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
+ Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr,
+ Value *ThreadLimit = nullptr, Value *IfExpr = nullptr);
/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a658990f2d45355..5b24e9fe2e0c5bd 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
- Value *NumTeamsUpper, Value *ThreadLimit) {
+ Value *NumTeamsUpper, Value *ThreadLimit,
+ Value *IfExpr) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
// Push num_teams
- if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
+ if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
"if lowerbound is non-null, then upperbound must also be non-null "
"for bounds on num_teams");
@@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
if (NumTeamsLower == nullptr)
NumTeamsLower = NumTeamsUpper;
+ if (IfExpr) {
+ assert(IfExpr->getType()->isIntegerTy() &&
+ "argument to if clause must be an integer value");
+
+ // upper = ifexpr ? upper : 1
+ if (IfExpr->getType() != Int1)
+ IfExpr = Builder.CreateICmpNE(IfExpr,
+ ConstantInt::get(IfExpr->getType(), 0));
+ NumTeamsUpper = Builder.CreateSelect(
+ IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
+
+ // lower = ifexpr ? lower : 1
+ NumTeamsLower = Builder.CreateSelect(
+ IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
+ }
+
if (ThreadLimit == nullptr)
ThreadLimit = Builder.getInt32(0);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d770facc1730252..97cfc339675f657 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
};
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
- Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+ /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
OMPBuilder.finalize();
Builder.CreateRetVoid();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/nullptr,
- /*ThreadLimit=*/F->arg_begin()));
+ /*ThreadLimit=*/F->arg_begin(),
+ /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
- /*NumTeamsUpper=*/F->arg_begin()));
+ /*NumTeamsUpper=*/F->arg_begin(),
+ /*ThreadLimit=*/nullptr,
+ /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(
- OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
+ OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
+ /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
};
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
- Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
- NumTeamsUpper, ThreadLimit));
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
+ Builder.CreateAlloca(Builder.getInt1Ty()));
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ // `F` already has an integer argument, so we use that as upper bound to
+ // `num_teams`
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+ /*ThreadLimit=*/nullptr, IfExpr));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+ Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
+ Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
+ Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);
+
+ // Check the lower_bound
+ ASSERT_NE(NumTeamsLower, nullptr);
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
+ ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
+ EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check the upper_bound
+ ASSERT_NE(NumTeamsUpper, nullptr);
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
+ ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
+ EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check thread_limit
+ EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Value *IfExpr = Builder.CreateLoad(
+ Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
+ Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
+ Value *NumTeamsUpper =
+ Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
+ Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ // `F` already has an integer argument, so we use that as upper bound to
+ // `num_teams`
+ Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
+ NumTeamsUpper, ThreadLimit, IfExpr));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+ Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
+ Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
+ Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);
+
+ // Get the boolean conversion of if expression
+ ASSERT_EQ(IfExpr->getNumUses(), 1U);
+ User *IfExprInst = IfExpr->user_back();
+ ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
+ ASSERT_NE(IfExprCmpInst, nullptr);
+ EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
+ EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
+ EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));
+
+ // Check the lower_bound
+ ASSERT_NE(NumTeamsLowerArg, nullptr);
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
+ ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check the upper_bound
+ ASSERT_NE(NumTeamsUpperArg, nullptr);
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
+ ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check thread_limit
+ EXPECT_EQ(ThreadLimitArg, ThreadLimit);
+}
+
/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>
|
Is this mentioned in the standard? Or is it specific to the llvm openmp library? If the former, please add a pointer to the standard. |
This is mentioned in the standard. Added reference to it in the description. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM.
This patch adds support for the
if
clause onteams
construct. The value of the argument must be an integer value. If the value evaluates to true (non-zero) integer, then the number of threads is determined bynum_threads
clause (or default and ICV ifnum_threads
is absent). When the condition evaluates to false (zero), then the bounds are set to 1. (OpenMP 5.2 Section 10.2)This essentially means that