Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openacc] Fix unstructured code in OpenACC region ops #66284

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,31 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
firOpBuilder.setInsertionPointToStart(&block);
}

/// Create empty blocks for the current region.
/// These blocks replace blocks parented to an enclosing region.
template <typename... TerminatorOps>
void createEmptyRegionBlocks(
fir::FirOpBuilder &builder,
std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
mlir::Region *region = &builder.getRegion();
for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
if (eval.block) {
if (eval.block->empty()) {
eval.block->erase();
eval.block = builder.createBlock(region);
} else {
[[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
assert(mlir::isa<TerminatorOps...>(terminatorOp) &&
"expected terminator op");
}
}
if (!eval.isDirective() && eval.hasNestedEvaluations())
createEmptyRegionBlocks<TerminatorOps...>(builder,
eval.getNestedEvaluations());
}
}

} // namespace lower
} // namespace Fortran

#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
87 changes: 56 additions & 31 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1250,23 +1250,32 @@ static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
}

template <typename Op, typename Terminator>
static Op
createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments) {
static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
bool outerCombined = false) {
llvm::ArrayRef<mlir::Type> argTy;
Op op = builder.create<Op>(loc, argTy, operands);
builder.createBlock(&op.getRegion());
mlir::Block &block = op.getRegion().back();
builder.setInsertionPointToStart(&block);
builder.create<Terminator>(loc);

op->setAttr(Op::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(operandSegments));

// Place the insertion point to the start of the first block.
builder.setInsertionPointToStart(&block);

// If it is an unstructured region and is not the outer region of a combined
// construct, create empty blocks for all evaluations.
if (eval.lowerAsUnstructured() && !outerCombined)
Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
mlir::acc::YieldOp>(
builder, eval.getNestedEvaluations());

builder.create<Terminator>(loc);
builder.setInsertionPointToStart(&block);
return op;
}

Expand Down Expand Up @@ -1347,6 +1356,7 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
static mlir::acc::LoopOp
createLoopOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
Expand Down Expand Up @@ -1455,7 +1465,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
addOperands(operands, operandSegments, cacheOperands);

auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
builder, currentLocation, operands, operandSegments);
builder, currentLocation, eval, operands, operandSegments);

if (hasGang)
loopOp.setHasGangAttr(builder.getUnitAttr());
Expand Down Expand Up @@ -1504,6 +1514,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,

static void genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {

const auto &beginLoopDirective =
Expand All @@ -1518,7 +1529,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
if (loopDirective.v == llvm::acc::ACCD_loop) {
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
}
}
Expand Down Expand Up @@ -1551,9 +1562,11 @@ template <typename Op>
static Op
createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
const Fortran::parser::AccClauseList &accClauseList,
bool outerCombined = false) {

// Parallel operation operands
mlir::Value async;
Expand Down Expand Up @@ -1769,10 +1782,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
Op computeOp;
if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
builder, currentLocation, operands, operandSegments);
builder, currentLocation, eval, operands, operandSegments,
outerCombined);
else
computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
builder, currentLocation, operands, operandSegments);
builder, currentLocation, eval, operands, operandSegments,
outerCombined);

if (addAsyncAttr)
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
Expand Down Expand Up @@ -1817,6 +1832,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,

static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
Expand Down Expand Up @@ -1942,7 +1958,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
return;

auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, operands, operandSegments);
builder, currentLocation, eval, operands, operandSegments);

dataOp.setAsyncAttr(addAsyncAttr);
dataOp.setWaitAttr(addWaitAttr);
Expand Down Expand Up @@ -1971,6 +1987,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
static void
genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
Expand Down Expand Up @@ -2020,7 +2037,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,

auto hostDataOp =
createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, operands, operandSegments);
builder, currentLocation, eval, operands, operandSegments);

if (addIfPresentAttr)
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
Expand All @@ -2029,6 +2046,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
static void
genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
Expand All @@ -2041,26 +2059,30 @@ genACC(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;

if (blockDirective.v == llvm::acc::ACCD_parallel) {
createComputeOp<mlir::acc::ParallelOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
semanticsContext, stmtCtx,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_data) {
genACCDataOp(converter, currentLocation, semanticsContext, stmtCtx,
genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_serial) {
createComputeOp<mlir::acc::SerialOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
semanticsContext, stmtCtx,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_kernels) {
createComputeOp<mlir::acc::KernelsOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
semanticsContext, stmtCtx,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, semanticsContext, stmtCtx,
accClauseList);
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
stmtCtx, accClauseList);
}
}

static void
genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
const auto &beginCombinedDirective =
std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
Expand All @@ -2075,18 +2097,21 @@ genACC(Fortran::lower::AbstractConverter &converter,

if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
createComputeOp<mlir::acc::KernelsOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
} else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
createComputeOp<mlir::acc::ParallelOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
} else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
createComputeOp<mlir::acc::SerialOp>(
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
} else {
llvm::report_fatal_error("Unknown combined construct encountered");
Expand Down Expand Up @@ -3169,14 +3194,14 @@ void Fortran::lower::genOpenACCConstruct(
std::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
genACC(converter, semanticsContext, blockConstruct);
genACC(converter, semanticsContext, eval, blockConstruct);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
genACC(converter, semanticsContext, combinedConstruct);
genACC(converter, semanticsContext, eval, combinedConstruct);
},
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
genACC(converter, semanticsContext, loopConstruct);
genACC(converter, semanticsContext, eval, loopConstruct);
},
[&](const Fortran::parser::OpenACCStandaloneConstruct
&standaloneConstruct) {
Expand Down
27 changes: 3 additions & 24 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1981,29 +1981,6 @@ static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
}

/// Create empty blocks for the current region.
/// These blocks replace blocks parented to an enclosing region.
static void createEmptyRegionBlocks(
fir::FirOpBuilder &firOpBuilder,
std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
mlir::Region *region = &firOpBuilder.getRegion();
for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
if (eval.block) {
if (eval.block->empty()) {
eval.block->erase();
eval.block = firOpBuilder.createBlock(region);
} else {
[[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
"expected terminator op");
}
}
if (!eval.isDirective() && eval.hasNestedEvaluations())
createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
}
}

static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
mlir::Operation *storeOp,
mlir::Block &block) {
Expand Down Expand Up @@ -2092,7 +2069,9 @@ static void createBodyOfOp(
// If it is an unstructured region and is not the outer region of a combined
// construct, create empty blocks for all evaluations.
if (eval.lowerAsUnstructured() && !outerCombined)
createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
mlir::omp::YieldOp>(
firOpBuilder, eval.getNestedEvaluations());

// Insert the terminator.
if constexpr (std::is_same_v<Op, mlir::omp::WsLoopOp> ||
Expand Down
86 changes: 86 additions & 0 deletions flang/test/Lower/OpenACC/acc-unstructured.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

subroutine test_unstructured1(a, b, c)
integer :: i, j, k
real :: a(:,:,:), b(:,:,:), c(:,:,:)

!$acc data copy(a, b, c)

!$acc kernels
a(:,:,:) = 0.0
!$acc end kernels

!$acc kernels
do i = 1, 10
do j = 1, 10
do k = 1, 10
end do
end do
end do
!$acc end kernels

do i = 1, 10
do j = 1, 10
do k = 1, 10
end do
end do

if (a(1,2,3) > 10) stop 'just to be unstructured'
end do

!$acc end data

end subroutine

! CHECK-LABEL: func.func @_QPtest_unstructured1
! CHECK: acc.data
! CHECK: acc.kernels
! CHECK: acc.kernels
! CHECK: fir.call @_FortranAStopStatementText


subroutine test_unstructured2(a, b, c)
integer :: i, j, k
real :: a(:,:,:), b(:,:,:), c(:,:,:)

!$acc parallel loop
do i = 1, 10
do j = 1, 10
do k = 1, 10
if (a(1,2,3) > 10) stop 'just to be unstructured'
end do
end do
end do

! CHECK-LABEL: func.func @_QPtest_unstructured2
! CHECK: acc.parallel
! CHECK: acc.loop
! CHECK: fir.call @_FortranAStopStatementText
! CHECK: fir.unreachable
! CHECK: acc.yield
! CHECK: acc.yield

end subroutine

subroutine test_unstructured3(a, b, c)
integer :: i, j, k
real :: a(:,:,:), b(:,:,:), c(:,:,:)

!$acc parallel
do i = 1, 10
do j = 1, 10
do k = 1, 10
if (a(1,2,3) > 10) stop 'just to be unstructured'
end do
end do
end do
!$acc end parallel

! CHECK-LABEL: func.func @_QPtest_unstructured3
! CHECK: acc.parallel
! CHECK: fir.call @_FortranAStopStatementText
! CHECK: fir.unreachable
! CHECK: acc.yield

end subroutine
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenACC/stop-stmt-in-region.f90
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ subroutine test_stop_in_region1()
! CHECK: %[[VAL_2:.*]] = arith.constant false
! CHECK: %[[VAL_3:.*]] = arith.constant false
! CHECK: %[[VAL_4:.*]] = fir.call @_FortranAStopStatement(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) {{.*}} : (i32, i1, i1) -> none
! CHECK: acc.yield
! CHECK: fir.unreachable
! CHECK: }
! CHECK: return
! CHECK: }
Expand Down