Skip to content

Commit

Permalink
[flang][OpenMP] Fix construct privatization in default clause (#72510)
Browse files Browse the repository at this point in the history
Current implementation of default clause privatization incorrectly fails
to privatize in presence of non-OpenMP constructs (i.e. nested
constructs with regions whose symbols need to be privatized in the scope
of the parent OpenMP construct). This patch fixes the same by
considering non-OpenMP constructs separately by collecting symbols of a
nested region if it is a non-OpenMP construct with a region, and
privatizing it in the scope of the parent OpenMP construct.

Fixes #71914 and
#71915
  • Loading branch information
NimishMishra authored May 3, 2024
1 parent 6b94870 commit 37f6ba4
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 22 deletions.
7 changes: 5 additions & 2 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ class AbstractConverter {
virtual bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) = 0;

/// Collect the set of symbols with \p flag in \p eval
/// region if \p collectSymbols is true. Likewise, collect the
/// region if \p collectSymbols is true. Otherwise, collect the
/// set of the host symbols with \p flag of the associated symbols in \p eval
/// region if collectHostAssociatedSymbols is true.
/// region if collectHostAssociatedSymbols is true. This allows gathering
/// host association details of symbols particularly in nested directives
/// irrespective of \p flag \p, and can be useful where host
/// association details are needed in flag-agnostic manner.
virtual void collectSymbolSet(
pft::Evaluation &eval,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet,
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool collectSymbol) {
if (collectSymbol && oriSymbol.test(flag))
symbolSet.insert(&oriSymbol);
if (checkHostAssociatedSymbols)
else if (checkHostAssociatedSymbols)
if (const auto *details{
oriSymbol
.detailsIf<Fortran::semantics::HostAssocDetails>()})
Expand Down
38 changes: 27 additions & 11 deletions flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,21 +302,38 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
}
}

void DataSharingProcessor::collectSymbolsInNestedRegions(
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::Symbol::Flag flag,
llvm::SetVector<const Fortran::semantics::Symbol *>
&symbolsInNestedRegions) {
for (Fortran::lower::pft::Evaluation &nestedEval :
eval.getNestedEvaluations()) {
if (nestedEval.hasNestedEvaluations()) {
if (nestedEval.isConstruct())
// Recursively look for OpenMP constructs within `nestedEval`'s region
collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions);
else
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/false);
}
}
}

// Collect symbols to be default privatized in two steps.
// In step 1, collect all symbols in `eval` that match `flag` into
// `defaultSymbols`. In step 2, for nested constructs (if any), if and only if
// the nested construct is an OpenMP construct, collect those nested
// symbols skipping host associated symbols into `symbolsInNestedRegions`.
// Later, in current context, all symbols in the set
// `defaultSymbols` - `symbolsInNestedRegions` will be privatized.
void DataSharingProcessor::collectSymbols(
Fortran::semantics::Symbol::Flag flag) {
converter.collectSymbolSet(eval, defaultSymbols, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/true);
for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
if (e.hasNestedEvaluations())
converter.collectSymbolSet(e, symbolsInNestedRegions, flag,
/*collectSymbols=*/true,
/*collectHostAssociatedSymbols=*/false);
else
converter.collectSymbolSet(e, symbolsInParentRegions, flag,
/*collectSymbols=*/false,
/*collectHostAssociatedSymbols=*/true);
}
collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
}

void DataSharingProcessor::collectDefaultSymbols() {
Expand Down Expand Up @@ -367,7 +384,6 @@ void DataSharingProcessor::defaultPrivatize(
!sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
!Fortran::semantics::IsImpliedDoIndex(sym->GetUltimate()) &&
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
doPrivatize(sym, clauseOps, privateSyms);
}
Expand Down
6 changes: 5 additions & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class DataSharingProcessor {
llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
llvm::DenseMap<const Fortran::semantics::Symbol *, mlir::omp::PrivateClauseOp>
symToPrivatizer;
Fortran::lower::AbstractConverter &converter;
Expand All @@ -52,6 +51,11 @@ class DataSharingProcessor {

bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
void collectSymbolsInNestedRegions(
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::Symbol::Flag flag,
llvm::SetVector<const Fortran::semantics::Symbol *>
&symbolsInNestedRegions);
void collectOmpObjectListSymbol(
const omp::ObjectList &objects,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/default-clause-byref.f90
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,6 @@ subroutine nested_default_clause_tests
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
Expand All @@ -242,12 +240,14 @@ subroutine nested_default_clause_tests
!CHECK: omp.terminator
!CHECK: }
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
Expand Down
55 changes: 51 additions & 4 deletions flang/test/Lower/OpenMP/default-clause.f90
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ program default_clause_lowering
end program default_clause_lowering

subroutine nested_default_clause_tests
integer :: x, y, z, w, k, a
integer :: x, y, z, w, k
!CHECK: %[[K:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFnested_default_clause_testsEk"}
!CHECK: %[[K_DECL:.*]]:2 = hlfir.declare %[[K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[W:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFnested_default_clause_testsEw"}
Expand Down Expand Up @@ -221,13 +221,12 @@ subroutine nested_default_clause_tests


!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
Expand All @@ -242,12 +241,14 @@ subroutine nested_default_clause_tests
!CHECK: omp.terminator
!CHECK: }
!CHECK: omp.parallel {
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
Expand Down Expand Up @@ -415,3 +416,49 @@ subroutine threadprivate_with_default
end do
!$omp end parallel do
end subroutine

subroutine nested_constructs
!CHECK: %[[I:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFnested_constructsEi"}
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[I]] {{.*}}
!CHECK: %[[J:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFnested_constructsEj"}
!CHECK: %[[J_DECL:.*]]:2 = hlfir.declare %[[J]] {{.*}}
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {{.*}}
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_constructsEz"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {{.*}}

integer :: y, z
!CHECK: omp.parallel {
!CHECK: %[[INNER_J:.*]] = fir.alloca i32 {bindc_name = "j", pinned}
!CHECK: %[[INNER_J_DECL:.*]]:2 = hlfir.declare %[[INNER_J]] {{.*}}
!CHECK: %[[INNER_I:.*]] = fir.alloca i32 {bindc_name = "i", pinned}
!CHECK: %[[INNER_I_DECL:.*]]:2 = hlfir.declare %[[INNER_I]] {{.*}}
!CHECK: %[[INNER_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[INNER_Y_DECL:.*]]:2 = hlfir.declare %[[INNER_Y]] {{.*}}
!CHECK: %[[TEMP:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: hlfir.assign %[[TEMP]] to %[[INNER_Y_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
!CHECK: %[[INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_constructsEz"}
!CHECK: %[[INNER_Z_DECL:.*]]:2 = hlfir.declare %[[INNER_Z]] {{.*}}
!$omp parallel default(private) firstprivate(y)
!CHECK: {{.*}} = fir.do_loop {{.*}} {
do i = 1, 10
!CHECK: %[[CONST_1:.*]] = arith.constant 1 : i32
!CHECK: hlfir.assign %[[CONST_1]] to %[[INNER_Y_DECL]]#0 : i32, !fir.ref<i32>
y = 1
!CHECK: {{.*}} = fir.do_loop {{.*}} {
do j = 1, 10
!CHECK: %[[CONST_20:.*]] = arith.constant 20 : i32
!CHECK: hlfir.assign %[[CONST_20]] to %[[INNER_Z_DECL]]#0 : i32, !fir.ref<i32>
z = 20
!CHECK: omp.parallel {
!CHECK: %[[NESTED_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
!CHECK: %[[NESTED_Y_DECL:.*]]:2 = hlfir.declare %[[NESTED_Y]] {{.*}}
!CHECK: %[[CONST_2:.*]] = arith.constant 2 : i32
!CHECK: hlfir.assign %[[CONST_2]] to %[[NESTED_Y_DECL]]#0 : i32, !fir.ref<i32>
!$omp parallel default(private)
y = 2
!$omp end parallel
end do
end do
!$omp end parallel
end subroutine

0 comments on commit 37f6ba4

Please sign in to comment.