Skip to content

Commit

Permalink
Expose function replacement (rust-lang#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 18, 2022
1 parent 3dfdabc commit e496dfa
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
4 changes: 4 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,10 @@ void EnzymeSetMustCache(LLVMValueRef inst1) {
I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));
}

void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) {
ReplaceFunctionImplementation(*unwrap(M));
}

#if LLVM_VERSION_MAJOR >= 9
void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createAttributorLegacyPass());
Expand Down
63 changes: 34 additions & 29 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2070,42 +2070,16 @@ void SelectOptimization(Function *F) {
}
}
}
void PreProcessCache::optimizeIntermediate(Function *F) {
PromotePass().run(*F, FAM);
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
GVNPass().run(*F, FAM);
#else
GVN().run(*F, FAM);
#endif
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
SROAPass().run(*F, FAM);
#else
SROA().run(*F, FAM);
#endif

if (EnzymeSelectOpt) {
#if LLVM_VERSION_MAJOR >= 12
SimplifyCFGOptions scfgo;
#else
SimplifyCFGOptions scfgo(
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
#endif
SimplifyCFGPass(scfgo).run(*F, FAM);
CorrelatedValuePropagationPass().run(*F, FAM);
SelectOptimization(F);
}
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);

for (Function &Impl : *F->getParent()) {
void ReplaceFunctionImplementation(Module &M) {
for (Function &Impl : M) {
for (auto attr : {"implements", "implements2"}) {
if (!Impl.hasFnAttribute(attr))
continue;
const Attribute &A = Impl.getFnAttribute(attr);

const StringRef SpecificationName = A.getValueAsString();
Function *Specification = F->getParent()->getFunction(SpecificationName);
Function *Specification = M.getFunction(SpecificationName);
if (!Specification) {
LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
<< "' but no matching specification with name '"
Expand Down Expand Up @@ -2139,10 +2113,41 @@ void PreProcessCache::optimizeIntermediate(Function *F) {
}
}
}
}

void PreProcessCache::optimizeIntermediate(Function *F) {
PromotePass().run(*F, FAM);
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
GVNPass().run(*F, FAM);
#else
GVN().run(*F, FAM);
#endif
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
SROAPass().run(*F, FAM);
#else
SROA().run(*F, FAM);
#endif

if (EnzymeSelectOpt) {
#if LLVM_VERSION_MAJOR >= 12
SimplifyCFGOptions scfgo;
#else
SimplifyCFGOptions scfgo(
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
#endif
SimplifyCFGPass(scfgo).run(*F, FAM);
CorrelatedValuePropagationPass().run(*F, FAM);
SelectOptimization(F);
}
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);

if (EnzymeCoalese)
CoaleseTrivialMallocs(*F, FAM.getResult<DominatorTreeAnalysis>(*F));

ReplaceFunctionImplementation(*F->getParent());

PreservedAnalyses PA;
FAM.invalidate(*F, PA);
// TODO actually run post optimizations.
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/FunctionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,14 @@ static inline void calculateUnusedStores(
}
}

void ReplaceFunctionImplementation(llvm::Module &M);

/// Is the use of value val as an argument of call CI potentially captured
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val);
#endif

llvm::FunctionType *getFunctionTypeForClone(
llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType);

#endif

0 comments on commit e496dfa

Please sign in to comment.