From 2a8ba2d72b4bec5ad6d13438b2fa57bc8cbe313a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Oct 2022 15:15:42 -0400 Subject: [PATCH] Custom forward zero (#913) * Custom forward zero * Fix activity analysis of agg --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 +- enzyme/Enzyme/EnzymeLogic.cpp | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 395818f43957a..c9685ca13dbfd 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1054,7 +1054,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // This value is certainly an integer (and only and integer, not a pointer or // float). Therefore its value is constant - if (TR.intType(1, Val, /*errIfNotFound*/ false).isIntegral()) { + if (TR.query(Val)[{-1}] == BaseType::Integer) { if (EnzymePrintActivity) llvm::errs() << " Value const as integral " << (int)directions << " " << *Val << " " diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 69b419746f746..efe46b21eb035 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2828,8 +2828,11 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, } else if (!gutils->isConstantValue(ret)) { toret = gutils->diffe(ret, nBuilder); } else { + IRBuilder<> eB(gutils->inversionAllocs); Type *retTy = gutils->getShadowType(ret->getType()); - toret = Constant::getNullValue(retTy); + auto al = eB.CreateAlloca(retTy); + ZeroMemory(eB, retTy, al, /*isTape*/ false); + toret = nBuilder.CreateLoad(al); } break;