Skip to content

Commit

Permalink
[flang][NFC] use llvm.intr.stacksave/restore instead of opaque calls (#…
Browse files Browse the repository at this point in the history
…108562)

The new LLVM stack save/restore intrinsic operations are more convenient
than function calls because they do not add function declarations to the
module and therefore do not block the parallelisation of passes.
Furthermore, they could be marked with memory effects much more easily
than function calls, should that ever prove useful.

This builds on top of #107879.

Resolves #108016
  • Loading branch information
tblah authored Sep 16, 2024
1 parent 9548dbe commit 5aaf384
Show file tree
Hide file tree
Showing 25 changed files with 124 additions and 155 deletions.
21 changes: 21 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <utility>

namespace mlir {
class DataLayout;
class SymbolTable;
}

Expand Down Expand Up @@ -253,6 +254,15 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::ValueRange lenParams = {},
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});

/// Create an LLVM stack save intrinsic op. Returns the saved stack pointer.
/// The stack address space is fetched from the data layout of the current
/// module.
mlir::Value genStackSave(mlir::Location loc);

/// Create an LLVM stack restore intrinsic op. stackPointer should be a value
/// previously returned from genStackSave.
void genStackRestore(mlir::Location loc, mlir::Value stackPointer);

/// Create a global value.
fir::GlobalOp createGlobal(mlir::Location loc, mlir::Type type,
llvm::StringRef name,
Expand Down Expand Up @@ -523,6 +533,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
setCommonAttributes(op);
}

/// Return the data layout, constructing it on first use and caching it.
mlir::DataLayout &getDataLayout();

private:
/// Set attributes (e.g. FastMathAttr) to \p op operation
/// based on the current attributes setting.
Expand All @@ -537,6 +550,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// fir::GlobalOp and func::FuncOp symbol table to speed-up
/// lookups.
mlir::SymbolTable *symbolTable = nullptr;

/// DataLayout constructed on demand. Access via getDataLayout().
/// Stored via a unique_ptr rather than an optional so as not to bloat this
/// class when most instances won't ever need a data layout.
std::unique_ptr<mlir::DataLayout> dataLayout = nullptr;
};

} // namespace fir
Expand Down Expand Up @@ -729,6 +747,9 @@ elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape);
llvm::SmallVector<mlir::Value>
elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);

/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);

} // namespace fir::factory

#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
6 changes: 0 additions & 6 deletions flang/include/flang/Optimizer/Builder/LowLevelIntrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ mlir::func::FuncOp getLlvmGetRounding(FirOpBuilder &builder);
/// Get the `llvm.set.rounding` intrinsic.
mlir::func::FuncOp getLlvmSetRounding(FirOpBuilder &builder);

/// Get the `llvm.stacksave` intrinsic.
mlir::func::FuncOp getLlvmStackSave(FirOpBuilder &builder);

/// Get the `llvm.stackrestore` intrinsic.
mlir::func::FuncOp getLlvmStackRestore(FirOpBuilder &builder);

/// Get the `llvm.init.trampoline` intrinsic.
mlir::func::FuncOp getLlvmInitTrampoline(FirOpBuilder &builder);

Expand Down
1 change: 0 additions & 1 deletion flang/include/flang/Optimizer/Support/DataLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ void setMLIRDataLayoutFromAttributes(mlir::ModuleOp mlirModule,
/// std::nullopt.
std::optional<mlir::DataLayout>
getOrSetDataLayout(mlir::ModuleOp mlirModule, bool allowDefaultLayout = false);

} // namespace fir::support

#endif // FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
11 changes: 3 additions & 8 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3257,15 +3257,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::CharBlock &endPosition =
eval.getLastNestedEvaluation().position;
localSymbols.pushScope();
mlir::func::FuncOp stackSave = fir::factory::getLlvmStackSave(*builder);
mlir::func::FuncOp stackRestore =
fir::factory::getLlvmStackRestore(*builder);
mlir::Value stackPtr =
builder->create<fir::CallOp>(toLocation(), stackSave).getResult(0);
mlir::Value stackPtr = builder->genStackSave(toLocation());
mlir::Location endLoc = genLocation(endPosition);
stmtCtx.attachCleanup([=]() {
builder->create<fir::CallOp>(endLoc, stackRestore, stackPtr);
});
stmtCtx.attachCleanup(
[=]() { builder->genStackRestore(endLoc, stackPtr); });
Fortran::semantics::Scope &scope =
bridge.getSemanticsContext().FindScope(endPosition);
scopeBlockIdMap.try_emplace(&scope, ++blockId);
Expand Down
19 changes: 3 additions & 16 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,22 +368,9 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(

if (!extents.empty() || !lengths.empty()) {
auto *bldr = &converter.getFirOpBuilder();
auto stackSaveFn = fir::factory::getLlvmStackSave(builder);
auto stackSaveSymbol = bldr->getSymbolRefAttr(stackSaveFn.getName());
mlir::Value sp;
fir::CallOp call = bldr->create<fir::CallOp>(
loc, stackSaveSymbol, stackSaveFn.getFunctionType().getResults(),
mlir::ValueRange{});
if (call.getNumResults() != 0)
sp = call.getResult(0);
stmtCtx.attachCleanup([bldr, loc, sp]() {
auto stackRestoreFn = fir::factory::getLlvmStackRestore(*bldr);
auto stackRestoreSymbol =
bldr->getSymbolRefAttr(stackRestoreFn.getName());
bldr->create<fir::CallOp>(loc, stackRestoreSymbol,
stackRestoreFn.getFunctionType().getResults(),
mlir::ValueRange{sp});
});
mlir::Value sp = bldr->genStackSave(loc);
stmtCtx.attachCleanup(
[bldr, loc, sp]() { bldr->genStackRestore(loc, sp); });
}
mlir::Value temp =
builder.createTemporary(loc, type, ".result", extents, resultLengths);
Expand Down
28 changes: 28 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/Utils.h"
Expand Down Expand Up @@ -328,6 +329,17 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
name, dynamicLength, dynamicShape, attrs);
}

/// Create an mlir::LLVM::StackSaveOp at \p loc and return its result, the
/// saved stack pointer. The result is an LLVM pointer type whose address
/// space is the alloca address space obtained from this builder's data layout.
mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) {
  const uint64_t allocaAddrSpace =
      fir::factory::getAllocaAddressSpace(&getDataLayout());
  mlir::Type stackPtrTy =
      mlir::LLVM::LLVMPointerType::get(getContext(), allocaAddrSpace);
  return create<mlir::LLVM::StackSaveOp>(loc, stackPtrTy);
}

/// Create an mlir::LLVM::StackRestoreOp at \p loc with \p stackPointer as its
/// operand. \p stackPointer is expected to be a value previously returned by
/// genStackSave. Nothing is returned.
void fir::FirOpBuilder::genStackRestore(mlir::Location loc,
mlir::Value stackPointer) {
create<mlir::LLVM::StackRestoreOp>(loc, stackPointer);
}

/// Create a global variable in the (read-only) data section. A global variable
/// must have a unique name to identify and reference it.
fir::GlobalOp fir::FirOpBuilder::createGlobal(
Expand Down Expand Up @@ -791,6 +803,15 @@ void fir::FirOpBuilder::setFastMathFlags(
setFastMathFlags(arithFMF);
}

// Building an mlir::DataLayout is expensive, so it is created lazily from the
// builder's module on the first request and then cached in this builder
// instance; later calls return the same object.
mlir::DataLayout &fir::FirOpBuilder::getDataLayout() {
  if (!dataLayout)
    dataLayout = std::make_unique<mlir::DataLayout>(getModule());
  return *dataLayout;
}

//===--------------------------------------------------------------------===//
// ExtendedValue inquiry helper implementation
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1664,3 +1685,10 @@ void fir::factory::setInternalLinkage(mlir::func::FuncOp func) {
mlir::LLVM::LinkageAttr::get(func->getContext(), internalLinkage);
func->setAttr("llvm.linkage", linkage);
}

/// Return the alloca memory space recorded in \p dataLayout as an unsigned
/// integer. Returns 0 when \p dataLayout is null or when it reports no alloca
/// memory space attribute. A present attribute is cast to mlir::IntegerAttr.
uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
  if (!dataLayout)
    return 0;
  mlir::Attribute allocaSpaceAttr = dataLayout->getAllocaMemorySpace();
  if (!allocaSpaceAttr)
    return 0;
  return mlir::cast<mlir::IntegerAttr>(allocaSpaceAttr).getUInt();
}
19 changes: 0 additions & 19 deletions flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,25 +76,6 @@ fir::factory::getLlvmSetRounding(fir::FirOpBuilder &builder) {
funcTy);
}

mlir::func::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
// FIXME: This should query the target alloca address space
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy =
mlir::FunctionType::get(builder.getContext(), std::nullopt, {ptrTy});
return builder.createFunction(builder.getUnknownLoc(), "llvm.stacksave.p0",
funcTy);
}

mlir::func::FuncOp
fir::factory::getLlvmStackRestore(fir::FirOpBuilder &builder) {
// FIXME: This should query the target alloca address space
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {ptrTy}, std::nullopt);
return builder.createFunction(builder.getUnknownLoc(), "llvm.stackrestore.p0",
funcTy);
}

mlir::func::FuncOp
fir::factory::getLlvmInitTrampoline(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
Expand Down
15 changes: 4 additions & 11 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,25 +1236,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {

inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }

uint64_t getAllocaAddressSpace() const {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

// Inserts an llvm.intr.stacksave op at the current insertion
// point and the given location. Returns the saved stack pointer Value.
inline mlir::Value genStackSave(mlir::Location loc) {
mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
rewriter->getContext(), getAllocaAddressSpace());
return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
fir::FirOpBuilder builder(*rewriter, getModule());
return builder.genStackSave(loc);
}

// Inserts an llvm.intr.stackrestore op at the current insertion
// point and the given location, taking the saved stack pointer as argument.
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
fir::FirOpBuilder builder(*rewriter, getModule());
return builder.genStackRestore(loc, sp);
}

fir::CodeGenSpecifics *specifics = nullptr;
Expand Down
20 changes: 2 additions & 18 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,28 +734,12 @@ void AllocMemConversion::insertStackSaveRestore(
auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder{rewriter, mod};

mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
mlir::SymbolRefAttr stackSaveSym =
builder.getSymbolRefAttr(stackSaveFn.getName());

builder.setInsertionPoint(oldAlloc);
mlir::Value sp =
builder
.create<fir::CallOp>(oldAlloc.getLoc(), stackSaveSym,
stackSaveFn.getFunctionType().getResults(),
mlir::ValueRange{})
.getResult(0);

mlir::func::FuncOp stackRestoreFn =
fir::factory::getLlvmStackRestore(builder);
mlir::SymbolRefAttr stackRestoreSym =
builder.getSymbolRefAttr(stackRestoreFn.getName());
mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

auto createStackRestoreCall = [&](mlir::Operation *user) {
builder.setInsertionPoint(user);
builder.create<fir::CallOp>(user->getLoc(), stackRestoreSym,
stackRestoreFn.getFunctionType().getResults(),
mlir::ValueRange{sp});
builder.genStackRestore(user->getLoc(), sp);
};

for (mlir::Operation *user : oldAlloc->getUsers()) {
Expand Down
21 changes: 4 additions & 17 deletions flang/lib/Optimizer/Transforms/StackReclaim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
Expand All @@ -31,34 +32,20 @@ class StackReclaimPass : public fir::impl::StackReclaimBase<StackReclaimPass> {
};
} // namespace

uint64_t getAllocaAddressSpace(Operation *op) {
mlir::ModuleOp module = mlir::dyn_cast_or_null<mlir::ModuleOp>(op);
if (!module)
module = op->getParentOfType<mlir::ModuleOp>();

if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getAllocaMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

void StackReclaimPass::runOnOperation() {
auto *op = getOperation();
auto *context = &getContext();
mlir::OpBuilder builder(context);
mlir::Type voidPtr =
mlir::LLVM::LLVMPointerType::get(context, getAllocaAddressSpace(op));
fir::FirOpBuilder builder(op, fir::getKindMapping(op));

op->walk([&](fir::DoLoopOp loopOp) {
mlir::Location loc = loopOp.getLoc();

if (!loopOp.getRegion().getOps<fir::AllocaOp>().empty()) {
builder.setInsertionPointToStart(&loopOp.getRegion().front());
auto stackSaveOp = builder.create<LLVM::StackSaveOp>(loc, voidPtr);
mlir::Value sp = builder.genStackSave(loc);

auto *terminator = loopOp.getRegion().back().getTerminator();
builder.setInsertionPoint(terminator);
builder.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
builder.genStackRestore(loc, sp);
}
});
}
2 changes: 1 addition & 1 deletion flang/test/HLFIR/order_assignments/where-scheduling.f90
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ end function f
!CHECK-NEXT: run 1 save : where/mask
!CHECK-NEXT: run 2 evaluate: where/region_assign1
!CHECK-LABEL: ------------ scheduling where in _QPonly_once ------------
!CHECK-NEXT: unknown effect: %{{[0-9]+}} = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
!CHECK-NEXT: unknown effect: %{{[0-9]+}} = llvm.intr.stacksave : !llvm.ptr
!CHECK-NEXT: run 1 save (w): where/mask
!CHECK-NEXT: run 2 evaluate: where/region_assign1
!CHECK-NEXT: run 3 evaluate: where/region_assign2
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/block_bindc_pocs.f90
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ subroutine test_proc() bind(C)
end subroutine test_proc
end interface
end module m
!CHECK-DAG: %[[S0:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
!CHECK-DAG: %[[S0:.*]] = llvm.intr.stacksave : !llvm.ptr
!CHECK-DAG: fir.call @test_proc() proc_attrs<bind_c> fastmath<contract> : () -> ()
!CHECK-DAG: fir.call @llvm.stackrestore.p0(%[[S0]]) fastmath<contract> : (!fir.ref<i8>) -> ()
!CHECK-DAG: llvm.intr.stackrestore %[[S0]] : !llvm.ptr
!CHECK-DAG: func.func private @test_proc() attributes {fir.bindc_name = "test_proc"}
subroutine test
BLOCK
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/elemental-array-ops.f90
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ end subroutine char_return
! CHECK: %[[VAL_23:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_24:.*]] = arith.cmpi sgt, %[[VAL_22]], %[[VAL_23]] : index
! CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_24]], %[[VAL_22]], %[[VAL_23]] : index
! CHECK: %[[VAL_26:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
! CHECK: %[[VAL_26:.*]] = llvm.intr.stacksave : !llvm.ptr
! CHECK: %[[VAL_27:.*]] = fir.call @_QPcallee(%[[VAL_2]], %[[VAL_25]], %[[VAL_20]]) fastmath<contract> : (!fir.ref<!fir.char<1,3>>, index, !fir.boxchar<1>) -> !fir.boxchar<1>
! CHECK: %[[VAL_28:.*]]:2 = hlfir.declare %[[VAL_2]] typeparams %[[VAL_25]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.char<1,3>>, index) -> (!fir.ref<!fir.char<1,3>>, !fir.ref<!fir.char<1,3>>)
! CHECK: %[[MustFree:.*]] = arith.constant false
! CHECK: %[[ResultTemp:.*]] = hlfir.as_expr %[[VAL_28]]#0 move %[[MustFree]] : (!fir.ref<!fir.char<1,3>>, i1) -> !hlfir.expr<!fir.char<1,3>>
! CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_26]]) fastmath<contract> : (!fir.ref<i8>) -> ()
! CHECK: llvm.intr.stackrestore %[[VAL_26]] : !llvm.ptr
! CHECK: hlfir.yield_element %[[ResultTemp]] : !hlfir.expr<!fir.char<1,3>>
! CHECK: }
! CHECK: %[[VAL_29:.*]] = arith.constant 0 : index
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/proc-pointer-comp-pass.f90
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,6 @@ subroutine test5(x)
! CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_8:.*]] = arith.cmpi sgt, %[[VAL_6]], %[[VAL_7]] : index
! CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : index
! CHECK: %[[VAL_10:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
! CHECK: %[[VAL_10:.*]] = llvm.intr.stacksave : !llvm.ptr
! CHECK: %[[VAL_11:.*]] = fir.box_addr %[[VAL_4]] : (!fir.boxproc<(!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3{c:!fir.char<1,4>,p:!fir.boxproc<(!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3>>) -> !fir.boxchar<1>>}>>) -> !fir.boxchar<1>>) -> ((!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3{c:!fir.char<1,4>,p:!fir.boxproc<(!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3>>) -> !fir.boxchar<1>>}>>) -> !fir.boxchar<1>)
! CHECK: %[[VAL_12:.*]] = fir.call %[[VAL_11]](%[[VAL_1]], %[[VAL_9]], %[[VAL_2]]#1) fastmath<contract> : (!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3{c:!fir.char<1,4>,p:!fir.boxproc<(!fir.ref<!fir.char<1,4>>, index, !fir.ref<!fir.type<_QMmTt3>>) -> !fir.boxchar<1>>}>>) -> !fir.boxchar<1>
Loading

0 comments on commit 5aaf384

Please sign in to comment.