Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openacc][NFC] Simplify lowering of recipe #68836

Merged
merged 1 commit into from
Oct 16, 2023

Conversation

clementval
Copy link
Contributor

Refactor some of the lowering in the reduction and firstprivate recipe to avoid duplicated code.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc labels Oct 11, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2023

@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Refactor some of the lowering in the reduction and firstprivate recipe to avoid duplicated code.


Full diff: https://github.com/llvm/llvm-project/pull/68836.diff

1 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+74-101)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 61a1b9fd86717cb..e09266121cdb997 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -463,7 +463,7 @@ bool isConstantBound(mlir::acc::DataBoundsOp &op) {
 }
 
 /// Return true iff all the bounds are expressed with constant values.
-bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
+bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
   for (auto bound : bounds) {
     auto dataBound =
         mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
@@ -474,27 +474,6 @@ bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
   return true;
 }
 
-static fir::ShapeOp
-genShapeFromBounds(mlir::Location loc, fir::FirOpBuilder &builder,
-                   const llvm::SmallVector<mlir::Value> &args) {
-  assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
-  llvm::SmallVector<mlir::Value> extents;
-  mlir::Type idxTy = builder.getIndexType();
-  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
-  for (unsigned i = 0; i < args.size(); i += 3) {
-    mlir::Value s1 =
-        builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
-    mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
-    mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
-    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
-    mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
-    extents.push_back(ext);
-  }
-  return builder.create<fir::ShapeOp>(loc, extents);
-}
-
 static llvm::SmallVector<mlir::Value>
 genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
                   mlir::acc::DataBoundsOp &dataBound) {
@@ -520,6 +499,63 @@ genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
   return {lb, ub, step};
 }
 
+static fir::ShapeOp genShapeFromBoundsOrArgs(
+    mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
+    const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
+  llvm::SmallVector<mlir::Value> args;
+  if (areAllBoundConstant(bounds)) {
+    for (auto bound : llvm::reverse(bounds)) {
+      auto dataBound =
+          mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
+      args.append(genConstantBounds(builder, loc, dataBound));
+    }
+  } else {
+    assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
+           "Expect 3 block arguments per dimension");
+    for (auto arg : arguments.drop_front(2))
+      args.push_back(arg);
+  }
+
+  assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
+  llvm::SmallVector<mlir::Value> extents;
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+  for (unsigned i = 0; i < args.size(); i += 3) {
+    mlir::Value s1 =
+        builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
+    mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
+    mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
+    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
+    mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
+    extents.push_back(ext);
+  }
+  return builder.create<fir::ShapeOp>(loc, extents);
+}
+
+static hlfir::DesignateOp::Subscripts
+getSubscriptsFromArgs(mlir::ValueRange args) {
+  hlfir::DesignateOp::Subscripts triplets;
+  for (unsigned i = 2; i < args.size(); i += 3)
+    triplets.emplace_back(
+        hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
+  return triplets;
+}
+
+static hlfir::Entity genDesignateWithTriplets(
+    fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
+    hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
+  llvm::SmallVector<mlir::Value> lenParams;
+  hlfir::genLengthParameters(loc, builder, entity, lenParams);
+  auto designate = builder.create<hlfir::DesignateOp>(
+      loc, entity.getBase().getType(), entity, /*component=*/"",
+      /*componentShape=*/mlir::Value{}, triplets,
+      /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
+      lenParams);
+  return hlfir::Entity{designate.getResult()};
+}
+
 mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
     mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
     mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
@@ -600,47 +636,16 @@ mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
     if (!seqTy)
       TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
 
-    if (allConstantBound) {
-      for (auto bound : llvm::reverse(bounds)) {
-        auto dataBound =
-            mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
-        tripletArgs.append(genConstantBounds(firBuilder, loc, dataBound));
-      }
-    } else {
-      assert(((recipe.getCopyRegion().getArguments().size() - 2) / 3 ==
-              seqTy.getDimension()) &&
-             "Expect 3 block arguments per dimension");
-      for (auto arg : recipe.getCopyRegion().getArguments().drop_front(2))
-        tripletArgs.push_back(arg);
-    }
-    auto shape = genShapeFromBounds(loc, firBuilder, tripletArgs);
-    hlfir::DesignateOp::Subscripts triplets;
-    for (unsigned i = 2; i < recipe.getCopyRegion().getArguments().size();
-         i += 3)
-      triplets.emplace_back(hlfir::DesignateOp::Triplet{
-          recipe.getCopyRegion().getArgument(i),
-          recipe.getCopyRegion().getArgument(i + 1),
-          recipe.getCopyRegion().getArgument(i + 2)});
-
-    llvm::SmallVector<mlir::Value> lenParamsLeft;
+    auto shape = genShapeFromBoundsOrArgs(
+        loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
+    hlfir::DesignateOp::Subscripts triplets =
+        getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
     auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
-    hlfir::genLengthParameters(loc, firBuilder, leftEntity, lenParamsLeft);
-    auto leftDesignate = firBuilder.create<hlfir::DesignateOp>(
-        loc, leftEntity.getBase().getType(), leftEntity, /*component=*/"",
-        /*componentShape=*/mlir::Value{}, triplets,
-        /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
-        shape, lenParamsLeft);
-    auto left = hlfir::Entity{leftDesignate.getResult()};
-
-    llvm::SmallVector<mlir::Value> lenParamsRight;
+    auto left =
+        genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
     auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
-    hlfir::genLengthParameters(loc, firBuilder, rightEntity, lenParamsRight);
-    auto rightDesignate = firBuilder.create<hlfir::DesignateOp>(
-        loc, rightEntity.getBase().getType(), rightEntity, /*component=*/"",
-        /*componentShape=*/mlir::Value{}, triplets,
-        /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
-        shape, lenParamsRight);
-    auto right = hlfir::Entity{rightDesignate.getResult()};
+    auto right =
+        genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
     firBuilder.create<hlfir::AssignOp>(loc, left, right);
   }
 
@@ -1110,48 +1115,16 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
     if (!seqTy)
       TODO(loc, "Unsupported boxed type in OpenACC reduction");
 
-    if (allConstantBound) {
-      for (auto bound : llvm::reverse(bounds)) {
-        auto dataBound =
-            mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
-        tripletArgs.append(genConstantBounds(builder, loc, dataBound));
-      }
-    } else {
-      assert(((recipe.getCombinerRegion().getArguments().size() - 2) / 3 ==
-              seqTy.getDimension()) &&
-             "Expect 3 block arguments per dimension");
-      for (auto arg : recipe.getCombinerRegion().getArguments().drop_front(2))
-        tripletArgs.push_back(arg);
-    }
-    auto shape = genShapeFromBounds(loc, builder, tripletArgs);
-
-    hlfir::DesignateOp::Subscripts triplets;
-    for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
-         i += 3)
-      triplets.emplace_back(hlfir::DesignateOp::Triplet{
-          recipe.getCombinerRegion().getArgument(i),
-          recipe.getCombinerRegion().getArgument(i + 1),
-          recipe.getCombinerRegion().getArgument(i + 2)});
-
-    llvm::SmallVector<mlir::Value> lenParamsLeft;
+    auto shape = genShapeFromBoundsOrArgs(
+        loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
+    hlfir::DesignateOp::Subscripts triplets =
+        getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
     auto leftEntity = hlfir::Entity{value1};
-    hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
-    auto leftDesignate = builder.create<hlfir::DesignateOp>(
-        loc, value1.getType(), leftEntity, /*component=*/"",
-        /*componentShape=*/mlir::Value{}, triplets,
-        /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
-        shape, lenParamsLeft);
-    auto left = hlfir::Entity{leftDesignate.getResult()};
-
-    llvm::SmallVector<mlir::Value> lenParamsRight;
+    auto left =
+        genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
     auto rightEntity = hlfir::Entity{value2};
-    hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsRight);
-    auto rightDesignate = builder.create<hlfir::DesignateOp>(
-        loc, value2.getType(), rightEntity, /*component=*/"",
-        /*componentShape=*/mlir::Value{}, triplets,
-        /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
-        shape, lenParamsRight);
-    auto right = hlfir::Entity{rightDesignate.getResult()};
+    auto right =
+        genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
 
     llvm::SmallVector<mlir::Value, 1> typeParams;
     auto genKernel = [&builder, &loc, op, seqTy, &left, &right](

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@clementval clementval merged commit 468d3b1 into llvm:main Oct 16, 2023
@clementval clementval deleted the acc_lower_simplify branch October 16, 2023 16:35
clementval added a commit that referenced this pull request Oct 16, 2023
…ecipe (#69026)

Add lowering support for array with dynamic extents in the firstprivate
recipe. Generalize the lowering so static shaped arrays and array with
dynamic extents use the same path.

Some cleaning code is taken from #68836 that is not landed yet.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants