From cb7900e45200b35ae2c349ebcc3588651997e9d2 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 30 Aug 2018 14:49:56 -0400 Subject: [PATCH] Vector lowering improvements in GC placement Support for vectors of tracked pointers was incomplete in the GC placement pass. Try to fix as many cases as possible and add some tests. A refactor to make all of this nicer might be good (vectors weren't originally part of the implementation), but for now, let's get it correct first. Fixes #28536 (cherry picked from commit b1dac9fdb037adbebf3d781b25d4bf4b64be2486) --- src/llvm-late-gc-lowering.cpp | 236 +++++++++++++++++++++++----------- test/llvmpasses/gcroots.ll | 105 +++++++++++++++ 2 files changed, 266 insertions(+), 75 deletions(-) diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index d83f5785f9cb1..902e110cf05cf 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -349,8 +349,8 @@ struct LateLowerGCFrame: public FunctionPass { NoteUse(S, BBS, V, BBS.UpExposedUses); } Value *MaybeExtractUnion(std::pair Val, Instruction *InsertBefore); - int LiftPhi(State &S, PHINode *Phi); - int LiftSelect(State &S, SelectInst *SI); + void LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers); + bool LiftSelect(State &S, SelectInst *SI); int Number(State &S, Value *V); std::vector NumberVector(State &S, Value *Vec); int NumberBase(State &S, Value *V, Value *Base); @@ -383,7 +383,10 @@ struct LateLowerGCFrame: public FunctionPass { }; static unsigned getValueAddrSpace(Value *V) { - return cast(V->getType())->getAddressSpace(); + Type *Ty = V->getType(); + if (isa(Ty)) + Ty = cast(V->getType())->getElementType(); + return cast(Ty)->getAddressSpace(); } static bool isSpecialPtr(Type *Ty) { @@ -508,42 +511,108 @@ Value *LateLowerGCFrame::MaybeExtractUnion(std::pair Val, Instructio return Val.first; } -int LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { - Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), 
false), SI); - Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - TrueBase = ConstantPointerNull::get(cast(FalseBase->getType())); - if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked) - FalseBase = ConstantPointerNull::get(cast(TrueBase->getType())); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - return -1; - Value *SelectBase = SelectInst::Create(SI->getCondition(), - TrueBase, FalseBase, "gclift", SI); - int Number = ++S.MaxPtrNumber; - S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] = - S.AllPtrNumbering[SI] = Number; - S.ReversePtrNumbering[Number] = SelectBase; - return Number; +static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) +{ + Value *Val = S.ReversePtrNumbering[Num]; + if (isSpecialPtrVec(Val->getType())) { + const std::vector &AllNums = S.AllVectorNumbering[Val]; + unsigned Idx = 0; + for (; Idx < AllNums.size(); ++Idx) { + if ((unsigned)AllNums[Idx] == Num) + break; + } + Val = ExtractElementInst::Create(Val, ConstantInt::get( + Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint); + } + return Val; +} + +bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { + if (isSpecialPtrVec(SI->getType())) { + VectorType *VT = cast(SI->getType()); + std::vector TrueNumbers = NumberVector(S, SI->getTrueValue()); + std::vector FalseNumbers = NumberVector(S, SI->getFalseValue()); + std::vector Numbers; + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + SelectInst *LSI = SelectInst::Create(SI->getCondition(), + TrueNumbers[i] < 0 ? + ConstantPointerNull::get(cast(T_prjlvalue)) : + GetPtrForNumber(S, TrueNumbers[i], SI), + FalseNumbers[i] < 0 ? 
+ ConstantPointerNull::get(cast(T_prjlvalue)) : + GetPtrForNumber(S, FalseNumbers[i], SI), + "gclift", SI); + int Number = ++S.MaxPtrNumber; + Numbers.push_back(Number); + S.PtrNumbering[LSI] = S.AllPtrNumbering[LSI] = Number; + S.ReversePtrNumbering[Number] = LSI; + } + S.AllVectorNumbering[SI] = Numbers; + } else { + Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI); + Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI); + if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) + TrueBase = ConstantPointerNull::get(cast(FalseBase->getType())); + if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked) + FalseBase = ConstantPointerNull::get(cast(TrueBase->getType())); + if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) + return false; + Value *SelectBase = SelectInst::Create(SI->getCondition(), + TrueBase, FalseBase, "gclift", SI); + int Number = ++S.MaxPtrNumber; + S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] = + S.AllPtrNumbering[SI] = Number; + S.ReversePtrNumbering[Number] = SelectBase; + } + return true; } -int LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) +void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers) { - PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi); - for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { - Value *Incoming = Phi->getIncomingValue(i); - Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false), - Phi->getIncomingBlock(i)->getTerminator()); - if (getValueAddrSpace(Base) != AddressSpace::Tracked) - Base = ConstantPointerNull::get(cast(T_prjlvalue)); - if (Base->getType() != T_prjlvalue) - Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator()); - lift->addIncoming(Base, Phi->getIncomingBlock(i)); + if (isSpecialPtrVec(Phi->getType())) { + VectorType *VT = cast(Phi->getType()); + std::vector lifted; + for (unsigned i = 0; i 
< VT->getNumElements(); ++i) { + lifted.push_back(PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi)); + } + for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { + std::vector Numbers = NumberVector(S, Phi->getIncomingValue(i)); + BasicBlock *IncomingBB = Phi->getIncomingBlock(i); + Instruction *Terminator = IncomingBB->getTerminator(); + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + if (Numbers[i] < 0) + lifted[i]->addIncoming(ConstantPointerNull::get(cast(T_prjlvalue)), IncomingBB); + else + lifted[i]->addIncoming(GetPtrForNumber(S, Numbers[i], Terminator), IncomingBB); + } + } + std::vector Numbers; + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + int Number = ++S.MaxPtrNumber; + PHINumbers.push_back(Number); + Numbers.push_back(Number); + S.PtrNumbering[lifted[i]] = S.AllPtrNumbering[lifted[i]] = Number; + S.ReversePtrNumbering[Number] = lifted[i]; + } + S.AllVectorNumbering[Phi] = Numbers; + } else { + PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi); + for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { + Value *Incoming = Phi->getIncomingValue(i); + Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false), + Phi->getIncomingBlock(i)->getTerminator()); + if (getValueAddrSpace(Base) != AddressSpace::Tracked) + Base = ConstantPointerNull::get(cast(T_prjlvalue)); + if (Base->getType() != T_prjlvalue) + Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator()); + lift->addIncoming(Base, Phi->getIncomingBlock(i)); + } + int Number = ++S.MaxPtrNumber; + PHINumbers.push_back(Number); + S.PtrNumbering[lift] = S.AllPtrNumbering[lift] = + S.AllPtrNumbering[Phi] = Number; + S.ReversePtrNumbering[Number] = lift; } - int Number = ++S.MaxPtrNumber; - S.PtrNumbering[lift] = S.AllPtrNumbering[lift] = - S.AllPtrNumbering[Phi] = Number; - S.ReversePtrNumbering[Number] = lift; - return Number; } int LateLowerGCFrame::NumberBase(State &S, 
Value *V, Value *CurrentV) @@ -566,12 +635,14 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV) // input IR) Number = -1; } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - int Number = LiftSelect(S, cast(CurrentV)); - S.AllPtrNumbering[V] = Number; + Number = -1; + if (LiftSelect(S, cast(CurrentV))) + Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - int Number = LiftPhi(S, cast(CurrentV)); - S.AllPtrNumbering[V] = Number; + SmallVector PHINumbers; + LiftPhi(S, cast(CurrentV), PHINumbers); + Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isUnion) { assert(false && "TODO: Extract"); @@ -630,7 +701,15 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { Numbers = NumberVectorBase(S, IEI->getOperand(0)); int ElNumber = Number(S, IEI->getOperand(1)); Numbers[idx] = ElNumber; - } else if (isa(CurrentV) || isa(CurrentV) || isa(CurrentV)) { + } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + LiftSelect(S, cast(CurrentV)); + Numbers = S.AllVectorNumbering[CurrentV]; + } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + SmallVector PHINumbers; + LiftPhi(S, cast(CurrentV), PHINumbers); + Numbers = S.AllVectorNumbering[CurrentV]; + } else if (isa(CurrentV) || isa(CurrentV) || isa(CurrentV) || + isa(CurrentV)) { // This is simple, we can just number them sequentially for (unsigned i = 0; i < cast(CurrentV->getType())->getNumElements(); ++i) { int Num = ++S.MaxPtrNumber; @@ -638,7 +717,7 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { S.ReversePtrNumbering[Num] = CurrentV; } } else { - assert(false && "Unexpected vector generating operating"); + assert(false && "Unexpected vector generating 
operation"); } S.AllVectorNumbering[CurrentV] = Numbers; return Numbers; @@ -1148,40 +1227,63 @@ State LateLowerGCFrame::LocalScan(Function &F) { NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted); } else if (SelectInst *SI = dyn_cast(&I)) { // We need to insert an extra select for the GC root - if (!isSpecialPtr(SI->getType()) && !isUnionRep(SI->getType())) + if (!isSpecialPtr(SI->getType()) && !isSpecialPtrVec(SI->getType()) && + !isUnionRep(SI->getType())) continue; if (!isUnionRep(SI->getType()) && getValueAddrSpace(SI) != AddressSpace::Tracked) { - if (S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end()) + if (isSpecialPtrVec(SI->getType()) ? + S.AllVectorNumbering.find(SI) != S.AllVectorNumbering.end() : + S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end()) continue; - auto Num = LiftSelect(S, SI); - if (Num < 0) + if (!LiftSelect(S, SI)) continue; - auto SelectBase = cast(S.ReversePtrNumbering[Num]); - SmallVector RefinedPtr{Number(S, SelectBase->getTrueValue()), - Number(S, SelectBase->getFalseValue())}; - S.Refinements[Num] = std::move(RefinedPtr); + if (!isSpecialPtrVec(SI->getType())) { + // TODO: Refinements for vector select + int Num = S.AllPtrNumbering[SI]; + if (Num < 0) + continue; + auto SelectBase = cast(S.ReversePtrNumbering[Num]); + SmallVector RefinedPtr{Number(S, SelectBase->getTrueValue()), + Number(S, SelectBase->getFalseValue())}; + S.Refinements[Num] = std::move(RefinedPtr); + } } else { - SmallVector RefinedPtr{Number(S, SI->getTrueValue()), - Number(S, SI->getFalseValue())}; + SmallVector RefinedPtr; + if (!isSpecialPtrVec(SI->getType())) { + RefinedPtr = { + Number(S, SI->getTrueValue()), + Number(S, SI->getFalseValue()) + }; + } MaybeNoteDef(S, BBS, SI, BBS.Safepoints, std::move(RefinedPtr)); NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted); } } else if (PHINode *Phi = dyn_cast(&I)) { - if (!isSpecialPtr(Phi->getType()) && !isUnionRep(Phi->getType())) { + if (!isSpecialPtr(Phi->getType()) && 
!isSpecialPtrVec(Phi->getType()) && + !isUnionRep(Phi->getType())) { continue; } auto nIncoming = Phi->getNumIncomingValues(); // We need to insert an extra phi for the GC root if (!isUnionRep(Phi->getType()) && getValueAddrSpace(Phi) != AddressSpace::Tracked) { - if (S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end()) + if (isSpecialPtrVec(Phi->getType()) ? + S.AllVectorNumbering.find(Phi) != S.AllVectorNumbering.end() : + S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end()) continue; - auto Num = LiftPhi(S, Phi); - auto lift = cast(S.ReversePtrNumbering[Num]); - S.Refinements[Num] = GetPHIRefinements(lift, S); - PHINumbers.push_back(Num); + LiftPhi(S, Phi, PHINumbers); } else { - MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, GetPHIRefinements(Phi, S)); - PHINumbers.push_back(Number(S, Phi)); + SmallVector PHIRefinements; + if (!isSpecialPtrVec(Phi->getType())) + PHIRefinements = GetPHIRefinements(Phi, S); + MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, std::move(PHIRefinements)); + if (isSpecialPtrVec(Phi->getType())) { + // TODO: Vector refinements + std::vector Nums = NumberVector(S, Phi); + for (int Num : Nums) + PHINumbers.push_back(Num); + } else { + PHINumbers.push_back(Number(S, Phi)); + } for (unsigned i = 0; i < nIncoming; ++i) { BBState &IncomingBBS = S.BBStates[Phi->getIncomingBlock(i)]; NoteUse(S, IncomingBBS, Phi->getIncomingValue(i), IncomingBBS.PhiOuts); @@ -1776,22 +1878,6 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) { return ChangesMade; } -static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) -{ - Value *Val = S.ReversePtrNumbering[Num]; - if (isSpecialPtrVec(Val->getType())) { - const std::vector &AllNums = S.AllVectorNumbering[Val]; - unsigned Idx = 0; - for (; Idx < AllNums.size(); ++Idx) { - if ((unsigned)AllNums[Idx] == Num) - break; - } - Val = ExtractElementInst::Create(Val, ConstantInt::get( - Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint); - } - return Val; -} - static void 
AddInPredLiveOuts(BasicBlock *BB, BitVector &LiveIn, State &S) { bool First = true; diff --git a/test/llvmpasses/gcroots.ll b/test/llvmpasses/gcroots.ll index 958a735b8efca..eb72e2e580d91 100644 --- a/test/llvmpasses/gcroots.ll +++ b/test/llvmpasses/gcroots.ll @@ -414,6 +414,111 @@ top: ret %jl_value_t addrspace(10)* %obj } +define void @vecphi(i1 %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecphi +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + br i1 %cond, label %A, label %B + +A: + br label %common + +B: + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + call void @jl_safepoint() + br label %common + +common: + %phi = phi <2 x %jl_value_t addrspace(10)*> [ zeroinitializer, %A ], [ %loaded, %B ] + call void @jl_safepoint() + %el1 = extractelement <2 x %jl_value_t addrspace(10)*> %phi, i32 0 + %el2 = extractelement <2 x %jl_value_t addrspace(10)*> %phi, i32 1 + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el1) + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el2) + unreachable +} + +define i8 @phi_arrayptr(i1 %cond) { +; CHECK-LABEL: @phi_arrayptr +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + br i1 %cond, label %A, label %B + +A: + %obj1 = call %jl_value_t addrspace(10) *@alloc() + %obj2 = call %jl_value_t addrspace(10) *@alloc() + %decayed1 = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) * + %arrayptrptr1 = bitcast %jl_value_t addrspace(11) *%decayed1 to i8 addrspace(13)* addrspace(11)* + %arrayptr1 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr1 + %decayed2 = addrspacecast %jl_value_t addrspace(10) *%obj2 to %jl_value_t addrspace(11) * + %arrayptrptr2 = bitcast %jl_value_t addrspace(11) *%decayed2 to i8 addrspace(13)* addrspace(11)* + %arrayptr2 = load i8 addrspace(13)*, i8 addrspace(13)* 
addrspace(11)* %arrayptrptr2 + %insert1 = insertelement <2 x i8 addrspace(13)*> undef, i8 addrspace(13)* %arrayptr1, i32 0 + %insert2 = insertelement <2 x i8 addrspace(13)*> %insert1, i8 addrspace(13)* %arrayptr2, i32 1 + call void @jl_safepoint() + br label %common + +B: + br label %common + +common: +; CHECK: %gclift +; CHECK: %gclift1 +; CHECK-NOT: %gclift2 + %phi = phi <2 x i8 addrspace(13)*> [ %insert2, %A ], [ zeroinitializer, %B ] + call void @jl_safepoint() + %el1 = extractelement <2 x i8 addrspace(13)*> %phi, i32 0 + %el2 = extractelement <2 x i8 addrspace(13)*> %phi, i32 1 + %l1 = load i8, i8 addrspace(13)* %el1 + %l2 = load i8, i8 addrspace(13)* %el2 + %add = add i8 %l1, %l2 + ret i8 %add +} + +define void @vecselect(i1 %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecselect +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + call void @jl_safepoint() + %select = select i1 %cond, <2 x %jl_value_t addrspace(10)*> zeroinitializer, <2 x %jl_value_t addrspace(10)*> %loaded + call void @jl_safepoint() + %el1 = extractelement <2 x %jl_value_t addrspace(10)*> %select, i32 0 + %el2 = extractelement <2 x %jl_value_t addrspace(10)*> %select, i32 1 + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el1) + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el2) + unreachable +} + +define i8 @select_arrayptr(i1 %cond) { +; CHECK-LABEL: @select_arrayptr +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + %obj1 = call %jl_value_t addrspace(10) *@alloc() + %obj2 = call %jl_value_t addrspace(10) *@alloc() + %decayed1 = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) * + %arrayptrptr1 = bitcast %jl_value_t addrspace(11) *%decayed1 to i8 addrspace(13)* addrspace(11)* + %arrayptr1 = load i8 
addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr1 + %decayed2 = addrspacecast %jl_value_t addrspace(10) *%obj2 to %jl_value_t addrspace(11) * + %arrayptrptr2 = bitcast %jl_value_t addrspace(11) *%decayed2 to i8 addrspace(13)* addrspace(11)* + %arrayptr2 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr2 + %insert1 = insertelement <2 x i8 addrspace(13)*> undef, i8 addrspace(13)* %arrayptr1, i32 0 + %insert2 = insertelement <2 x i8 addrspace(13)*> %insert1, i8 addrspace(13)* %arrayptr2, i32 1 + call void @jl_safepoint() + %select = select i1 %cond, <2 x i8 addrspace(13)*> %insert2, <2 x i8 addrspace(13)*> zeroinitializer + call void @jl_safepoint() + %el1 = extractelement <2 x i8 addrspace(13)*> %select, i32 0 + %el2 = extractelement <2 x i8 addrspace(13)*> %select, i32 1 + %l1 = load i8, i8 addrspace(13)* %el1 + %l2 = load i8, i8 addrspace(13)* %el2 + %add = add i8 %l1, %l2 + ret i8 %add +} + !0 = !{!"jtbaa"} !1 = !{!"jtbaa_const", !0, i64 0} !2 = !{!1, !1, i64 0, i64 1}