Skip to content

Commit

Permalink
[flang][openacc] Add support for allocatable and pointer arrays in re…
Browse files Browse the repository at this point in the history
…duction (#68261)

This patch adds support for allocatable and pointer arrays in the
reduction recipe lowering.
  • Loading branch information
clementval authored Oct 5, 2023
1 parent 017a003 commit 964a252
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 14 deletions.
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ bool isAssumedType(mlir::Type ty);
/// Return true iff `ty` is the type of an assumed shape array.
bool isAssumedShape(mlir::Type ty);

/// Return true iff `ty` is the type of an allocatable array.
bool isAllocatableOrPointerArray(mlir::Type ty);

/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic
/// entity. Polymorphic entities with intrinsic type spec do not have addendum
inline bool boxHasAddendum(fir::BaseBoxType boxTy) {
Expand Down
50 changes: 42 additions & 8 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,9 +742,28 @@ static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);

if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);

if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);

llvm::report_fatal_error("Unsupported OpenACC reduction type");
}

/// Return the nested sequence type if any.
static mlir::Type extractSequenceType(mlir::Type ty) {
if (mlir::isa<fir::SequenceType>(ty))
return ty;
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
return extractSequenceType(boxTy.getEleTy());
if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
return extractSequenceType(heapTy.getEleTy());
if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
return extractSequenceType(ptrTy.getEleTy());
return mlir::Type{};
}

static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Type ty,
mlir::acc::ReductionOperator op) {
Expand Down Expand Up @@ -788,7 +807,8 @@ static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
return declareOp.getBase();
}
} else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
if (!mlir::isa<fir::SequenceType>(boxTy.getEleTy()))
mlir::Type innerTy = extractSequenceType(boxTy);
if (!mlir::isa<fir::SequenceType>(innerTy))
TODO(loc, "Unsupported boxed type for reduction");
// Create the private copy from the initial fir.box.
hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
Expand Down Expand Up @@ -993,8 +1013,9 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder.setInsertionPointAfter(loops[0]);
} else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
llvm::SmallVector<mlir::Value> tripletArgs;
mlir::Type innerTy = extractSequenceType(boxTy);
fir::SequenceType seqTy =
mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC reduction");

Expand Down Expand Up @@ -1110,6 +1131,19 @@ mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
return recipe;
}

static bool isSupportedReductionType(mlir::Type ty) {
ty = fir::unwrapRefType(ty);
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
return isSupportedReductionType(boxTy.getEleTy());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
return isSupportedReductionType(seqTy.getEleTy());
if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
return isSupportedReductionType(heapTy.getEleTy());
if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
return isSupportedReductionType(ptrTy.getEleTy());
return fir::isa_trivial(ty);
}

static void
genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
Fortran::lower::AbstractConverter &converter,
Expand All @@ -1135,24 +1169,24 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
reductionTy = seqTy.getEleTy();

if (!fir::isa_trivial(reductionTy) &&
((fir::isAllocatableType(reductionTy) ||
fir::isPointerType(reductionTy)) &&
!bounds.empty()))
if (!isSupportedReductionType(reductionTy))
TODO(operandLocation, "reduction with unsupported type");

auto op = createDataEntryOp<mlir::acc::ReductionOp>(
builder, operandLocation, baseAddr, asFortran, bounds,
/*structured=*/true, /*implicit=*/false,
mlir::acc::DataClause::acc_reduction, baseAddr.getType());
mlir::Type ty = op.getAccPtr().getType();
if (!areAllBoundConstant(bounds) ||
fir::isAssumedShape(baseAddr.getType()) ||
fir::isAllocatableOrPointerArray(baseAddr.getType()))
ty = baseAddr.getType();
std::string suffix =
areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
std::string recipeName = fir::getTypeAsString(
ty, converter.getKindMap(),
("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
if (!areAllBoundConstant(bounds) || fir::isAssumedShape(baseAddr.getType()))
ty = baseAddr.getType();

mlir::acc::ReductionRecipeOp recipe =
Fortran::lower::createOrGetReductionRecipe(
builder, recipeName, operandLocation, ty, mlirOp, bounds);
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,18 @@ bool isAssumedShape(mlir::Type ty) {
return false;
}

bool isAllocatableOrPointerArray(mlir::Type ty) {
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
ty = refTy;
if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
}
return false;
}

bool isPolymorphicType(mlir::Type ty) {
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
ty = refTy;
Expand Down
77 changes: 71 additions & 6 deletions flang/test/Lower/OpenACC/acc-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,42 @@
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s --check-prefixes=CHECK,FIR
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s --check-prefixes=CHECK,HLFIR

! CHECK-LABEL: acc.reduction.recipe @reduction_add_section_lb1.ub3_ref_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_ptr_Uxf32 : !fir.box<!fir.ptr<!fir.array<?xf32>>> reduction_operator <max> init {
! CHECK: ^bb0(%{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>):
! CHECK: } combiner {
! CHECK: ^bb0(%{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>, %{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
! CHECK: }

! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_heap_Uxf32 : !fir.box<!fir.heap<!fir.array<?xf32>>> reduction_operator <max> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>):
! HLFIR: %[[CST:.*]] = arith.constant -1.401300e-45 : f32
! HLFIR: %[[C0:.*]] = arith.constant 0 : index
! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
! HLFIR: %[[TEMP:.*]] = fir.allocmem !fir.array<?xf32>, %[[BOX_DIMS]]#1 {bindc_name = ".tmp", uniq_name = ""}
! HLFIR: %[[DECLARE:.*]]:2 = hlfir.declare %2(%1) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.heap<!fir.array<?xf32>>)
! HLFIR: hlfir.assign %[[CST]] to %[[DECLARE]]#0 : f32, !fir.box<!fir.array<?xf32>>
! HLFIR: acc.yield %[[DECLARE]]#0 : !fir.box<!fir.array<?xf32>>
! CHECK: } combiner {
! HLFIR: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>, %[[ARG1:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
! HLFIR: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[ARG0]] (%[[ARG2]]:%[[ARG3]]:%[[ARG4]]) shape %[[SHAPE]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[ARG1]] (%[[ARG2]]:%[[ARG3]]:%[[ARG4]]) shape %[[SHAPE]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
! HLFIR: ^bb0(%[[IV:.*]]: index):
! HLFIR: %[[V1:.*]] = hlfir.designate %[[DES_V1]] (%[[IV]]) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
! HLFIR: %[[V2:.*]] = hlfir.designate %[[DES_V2]] (%[[IV]]) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
! HLFIR: %[[LOAD_V1:.*]] = fir.load %[[V1]] : !fir.ref<f32>
! HLFIR: %[[LOAD_V2:.*]] = fir.load %[[V2]] : !fir.ref<f32>
! HLFIR: %[[CMP:.*]] = arith.cmpf ogt, %[[LOAD_V1]], %[[LOAD_V2]] : f32
! HLFIR: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LOAD_V1]], %[[LOAD_V2]] : f32
! HLFIR: hlfir.yield_element %[[SELECT]] : f32
! HLFIR: }
! HLFIR: hlfir.assign %[[ELEMENTAL]] to %[[ARG0]] : !hlfir.expr<?xf32>, !fir.box<!fir.heap<!fir.array<?xf32>>>
! HLFIR: acc.yield %[[ARG0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
! CHECK: }

! CHECK-LABEL: acc.reduction.recipe @reduction_add_section_lb1.ub3_box_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %c0{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
Expand All @@ -29,7 +64,7 @@
! HLFIR: acc.yield %[[ARG0]] : !fir.box<!fir.array<?xi32>>
! HLFIR: }

! CHECK-LABEL: acc.reduction.recipe @reduction_max_ref_Uxf32 : !fir.box<!fir.array<?xf32>> reduction_operator <max> init {
! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_Uxf32 : !fir.box<!fir.array<?xf32>> reduction_operator <max> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xf32>>):
! CHECK: %[[INIT_VALUE:.*]] = arith.constant -1.401300e-45 : f32
! HLFIR: %[[C0:.*]] = arith.constant 0 : index
Expand Down Expand Up @@ -57,7 +92,7 @@
! CHECK: acc.yield %[[ARG0]] : !fir.box<!fir.array<?xf32>>
! CHECK: }

! CHECK-LABEL: acc.reduction.recipe @reduction_add_ref_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK-LABEL: acc.reduction.recipe @reduction_add_box_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[INIT_VALUE:.*]] = arith.constant 0 : i32
! HLFIR: %[[C0:.*]] = arith.constant 0 : index
Expand Down Expand Up @@ -1097,7 +1132,7 @@ subroutine acc_reduction_add_dynamic_extent_add(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xi32>> {name = "a"}
! HLFIR: acc.parallel reduction(@reduction_add_ref_Uxi32 -> %[[RED:.*]] : !fir.ref<!fir.array<?xi32>>)
! HLFIR: acc.parallel reduction(@reduction_add_box_Uxi32 -> %[[RED:.*]] : !fir.ref<!fir.array<?xi32>>)

subroutine acc_reduction_add_dynamic_extent_max(a)
real :: a(:)
Expand All @@ -1109,7 +1144,7 @@ subroutine acc_reduction_add_dynamic_extent_max(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
! HLFIR: acc.parallel reduction(@reduction_max_ref_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {
! HLFIR: acc.parallel reduction(@reduction_max_box_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {

subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
integer :: a(:)
Expand All @@ -1123,4 +1158,34 @@ subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c1{{.*}} : index) upperbound(%c3{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xi32>> {name = "a(2:4)"}
! HLFIR: acc.parallel reduction(@reduction_add_section_lb1.ub3_ref_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
! HLFIR: acc.parallel reduction(@reduction_add_section_lb1.ub3_box_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)

subroutine acc_reduction_add_allocatable(a)
real, allocatable :: a(:)
!$acc parallel reduction(max:a)
!$acc end parallel
end subroutine

! CHECK-LABEL: func.func @_QPacc_reduction_add_allocatable(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {fir.bindc_name = "a"})
! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFacc_reduction_add_allocatableEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
! HLFIR: %[[BOX:.*]] = fir.load %[[DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%6) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
! HLFIR: acc.parallel reduction(@reduction_max_box_heap_Uxf32 -> %[[RED]] : !fir.heap<!fir.array<?xf32>>)

subroutine acc_reduction_add_pointer_array(a)
real, pointer :: a(:)
!$acc parallel reduction(max:a)
!$acc end parallel
end subroutine

! CHECK-LABEL: func.func @_QPacc_reduction_add_pointer_array(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a"})
! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFacc_reduction_add_pointer_arrayEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
! HLFIR: %[[BOX:.*]] = fir.load %[[DECL]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
! HLFIR: acc.parallel reduction(@reduction_max_box_ptr_Uxf32 -> %[[RED]] : !fir.ptr<!fir.array<?xf32>>)

0 comments on commit 964a252

Please sign in to comment.