diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 5b221f0ab41c..3e7137b617e7 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -600,6 +600,10 @@ void EnzymeSetMustCache(LLVMValueRef inst1) { I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {})); } +/* C API entry point: unwraps the module handle M and passes the resulting module to ReplaceFunctionImplementation. Returns nothing. */ void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) { + ReplaceFunctionImplementation(*unwrap(M)); +} + #if LLVM_VERSION_MAJOR >= 9 void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createAttributorLegacyPass()); diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index ae4c12e6ff2c..a2cb4b62a254 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2070,42 +2070,16 @@ void SelectOptimization(Function *F) { } } } -void PreProcessCache::optimizeIntermediate(Function *F) { - PromotePass().run(*F, FAM); -#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - GVNPass().run(*F, FAM); -#else - GVN().run(*F, FAM); -#endif -#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - SROAPass().run(*F, FAM); -#else - SROA().run(*F, FAM); -#endif - if (EnzymeSelectOpt) { -#if LLVM_VERSION_MAJOR >= 12 - SimplifyCFGOptions scfgo; -#else - SimplifyCFGOptions scfgo( - /*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false, - /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, - /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr); -#endif - SimplifyCFGPass(scfgo).run(*F, FAM); - CorrelatedValuePropagationPass().run(*F, FAM); - SelectOptimization(F); - } - // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); - - for (Function &Impl : *F->getParent()) { +void ReplaceFunctionImplementation(Module &M) { + for (Function &Impl : M) { for (auto attr : {"implements", "implements2"}) { if (!Impl.hasFnAttribute(attr)) continue; const Attribute &A = Impl.getFnAttribute(attr); const StringRef SpecificationName = A.getValueAsString(); - Function *Specification = 
F->getParent()->getFunction(SpecificationName); + Function *Specification = M.getFunction(SpecificationName); if (!Specification) { LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName() << "' but no matching specification with name '" @@ -2139,10 +2113,41 @@ void PreProcessCache::optimizeIntermediate(Function *F) { } } } +} + +void PreProcessCache::optimizeIntermediate(Function *F) { + PromotePass().run(*F, FAM); +#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) + GVNPass().run(*F, FAM); +#else + GVN().run(*F, FAM); +#endif +#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) + SROAPass().run(*F, FAM); +#else + SROA().run(*F, FAM); +#endif + + if (EnzymeSelectOpt) { +#if LLVM_VERSION_MAJOR >= 12 + SimplifyCFGOptions scfgo; +#else + SimplifyCFGOptions scfgo( + /*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false, + /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, + /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr); +#endif + SimplifyCFGPass(scfgo).run(*F, FAM); + CorrelatedValuePropagationPass().run(*F, FAM); + SelectOptimization(F); + } + // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); if (EnzymeCoalese) CoaleseTrivialMallocs(*F, FAM.getResult(*F)); + ReplaceFunctionImplementation(*F->getParent()); + PreservedAnalyses PA; FAM.invalidate(*F, PA); // TODO actually run post optimizations. 
diff --git a/enzyme/Enzyme/FunctionUtils.h b/enzyme/Enzyme/FunctionUtils.h index 126c0334a1d6..2c088aa7b4ae 100644 --- a/enzyme/Enzyme/FunctionUtils.h +++ b/enzyme/Enzyme/FunctionUtils.h @@ -346,11 +346,14 @@ static inline void calculateUnusedStores( } } +/* Walks every function in M; for each one carrying an implements or implements2 function attribute, looks up the function in M named by that attribute value (the specification). NOTE(review): the handling after that lookup lies outside the visible hunks - confirm against the definition in FunctionUtils.cpp. */ void ReplaceFunctionImplementation(llvm::Module &M); + /// Is the use of value val as an argument of call CI potentially captured bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val); -#endif llvm::FunctionType *getFunctionTypeForClone( llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef constant_args, bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType); + +#endif