From bf346fbed66fc92a2c54249a95a5f4174cc37a9c Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 27 Oct 2021 02:17:26 +0900 Subject: [PATCH] optimizer: inline abstract union-split callsite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the optimizer handles abstract callsite only when there is a single dispatch candidate (in most cases), and so inlining and static-dispatch are prohibited when the callsite is union-split (in other word, union-split happens only when all the dispatch candidates are concrete). However, there are certain patterns of code (most notably our Julia-level compiler code) that inherently need to deal with abstract callsite. The following example is taken from `Core.Compiler` utility: ```julia julia> @inline isType(@nospecialize t) = isa(t, DataType) && t.name === Type.body.name isType (generic function with 1 method) julia> code_typed((Any,)) do x # abstract, but no union-split, successful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (x isa Main.DataType)::Bool └── goto #3 if not %1 2 ─ %3 = π (x, DataType) │ %4 = Base.getfield(%3, :name)::Core.TypeName │ %5 = Base.getfield(Type{T}, :name)::Core.TypeName │ %6 = (%4 === %5)::Bool └── goto #4 3 ─ goto #4 4 ┄ %9 = φ (#2 => %6, #3 => false)::Bool └── return %9 ) => Bool julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #4 3 ─ %4 = Main.isType(x)::Bool └── goto #4 4 ┄ %6 = φ (#2 => false, #3 => %4)::Bool └── return %6 ) => Bool ``` (note that this is a limitation of the inlining algorithm, and so any user-provided hints like callsite inlining annotation doesn't help here) This commit enables inlining and static dispatch for abstract union-split callsite. The core idea here is that we can simulate our dispatch semantics by generating `isa` checks in order of the specialities of dispatch candidates: ```julia julia> code_typed((Union{Type,Nothing},)) do x # union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #9 3 ─ %4 = (isa)(x, Type)::Bool └── goto #8 if not %4 4 ─ %6 = π (x, Type) │ %7 = (%6 isa Main.DataType)::Bool └── goto #6 if not %7 5 ─ %9 = π (%6, DataType) │ %10 = Base.getfield(%9, :name)::Core.TypeName │ %11 = Base.getfield(Type{T}, :name)::Core.TypeName │ %12 = (%10 === %11)::Bool └── goto #7 6 ─ goto #7 7 ┄ %15 = φ (#5 => %12, #6 => false)::Bool └── goto #9 8 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 9 ┄ %19 = φ (#2 => false, #7 => %15)::Bool └── return %19 ) => Bool ``` Inlining/static-dispatch of abstract union-split callsite will improve the performance in such situations (and so this commit will improve the latency of our JIT compilation). Especially, this commit helps us avoid excessive specializations of `Core.Compiler` code by statically-resolving `@nospecialize`d callsites, and as the result, the # of precompiled statements is now reduced from `1956` ([`master`](dc45d776a900ef17581a842952c51297065afa3a)) to `1901` (this commit). And also, as a side effect, the implementation of our inlining algorithm gets much simplified now since we no longer need the previous special handlings for abstract callsites. One possible drawback would be increased code size. This change seems to certainly increase the size of sysimage, but I think these numbers are in an acceptable range: > [`master`](dc45d776a900ef17581a842952c51297065afa3a) ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 188M usr/lib/julia/sys-o.a 164M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 101M usr/lib/julia/sys.ji ``` > this commit ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 190M usr/lib/julia/sys-o.a 166M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 102M usr/lib/julia/sys.ji ``` --- base/compiler/ssair/inlining.jl | 114 +++++++++----------------------- base/sort.jl | 2 +- test/compiler/inline.jl | 78 ++++++++++++++++++++-- 3 files changed, 107 insertions(+), 87 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index ab93375db4d0e2..e1af48677e3cbf 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig))) + if i != length(cases) || (!fully_covered || (!params.trust_inference)) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector spec = item.spec::ResolvedInliningSpec sparam_vals = item.mi.sparam_vals def = item.mi.def::Method - inline_cfg = spec.ir.cfg linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) @@ -472,6 +471,14 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, pn = PhiNode() local bb = compact.active_result_bb @assert length(bbs) >= length(cases) + # NOTE we are going to generate `isa` checks that correspond to the signatures of + # union-split dispatch candidates in order to simulate the dispatch semantics, + # and inline their bodies within each `isa`-conditional block -- and since we may + # deal with abstract union-split callsites here, these dispatch candidates need + # to be sorted in order of their signature specificity. + # Fortunately, ml_matches already sorted them in that way, so we can just process them + # in order here, assuming we haven't changed their order somewhere up to this point + # TODO assert that they are in the same order as sorted by ml_matches for i in 1:length(cases) ithcase = cases[i] mtype = ithcase.sig::DataType # checked within `handle_cases!` @@ -480,8 +487,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true nparams = fieldcount(atype) @assert nparams == fieldcount(mtype) - if i != length(cases) || !fully_covered || - (!params.trust_inference && isdispatchtuple(cases[i].sig)) + if i != length(cases) || !fully_covered || !params.trust_inference for i = 1:nparams a, m = fieldtype(atype, i), fieldtype(mtype, i) # If this is always true, we don't need to check for it @@ -538,7 +544,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the fall through block, decide what to do if fully_covered - if !params.trust_inference && isdispatchtuple(cases[end].sig) + if !params.trust_inference e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -561,7 +567,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect state = CFGInliningState(ir) for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params) + cfg_inline_unionsplit!(ir, idx, item, state, params) else item = item::InliningTodo spec = item.spec::ResolvedInliningSpec @@ -1175,12 +1181,8 @@ function analyze_single_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) argtypes = sig.argtypes cases = InliningCase[] - local only_method = nothing # keep track of whether there is one matching method - local meth::MethodLookupResult + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false - local revisit_idx = nothing - for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1191,66 +1193,20 @@ function analyze_single_call!( # No applicable methods; try next union split handled_all_cases = false continue - else - if length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method - only_method = false - end - else - only_method = false - end end - for (j, match) in enumerate(meth) - any_covers_full |= match.fully_covers - if !isdispatchtuple(match.spec_types) - if !match.fully_covers - handled_all_cases = false - continue - end - if revisit_idx === nothing - revisit_idx = (i, j) - else - handled_all_cases = false - revisit_idx = nothing - end - else - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) - end + for match in meth + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + any_fully_covered |= match.fully_covers end end - atype = argtypes_to_type(argtypes) - if handled_all_cases && revisit_idx !== nothing - # If there's only one case that's not a dispatchtuple, we can - # still unionsplit by visiting all the other cases first. - # This is useful for code like: - # foo(x::Int) = 1 - # foo(@nospecialize(x::Any)) = 2 - # where we where only a small number of specific dispatchable - # cases are split off from an ::Any typed fallback. - (i, j) = revisit_idx - match = infos[i].results[j] - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) - elseif length(cases) == 0 && only_method isa Method - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple. - # -- But don't try it if we already tried to handle the match in the revisit_idx - # case, because that'll (necessarily) be the same method. - if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) - else - @assert length(meth) == 1 - match = meth[1] - end - handle_match!(match, argtypes, flag, state, cases, true) || return nothing - any_covers_full = handled_all_cases = match.fully_covers + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1261,8 +1217,8 @@ function handle_const_call!( (; call, results) = cinfo infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches cases = InliningCase[] + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false local j = 0 for i in 1:length(infos) meth = infos[i].results @@ -1278,32 +1234,26 @@ function handle_const_call!( for match in meth j += 1 result = results[j] - any_covers_full |= match.fully_covers + any_fully_covered |= match.fully_covers if isa(result, ConstResult) case = const_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, InferenceResult) - handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) + handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) end end end - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple - atype = argtypes_to_type(argtypes) - if length(cases) == 0 - length(results) == 1 || return nothing - result = results[1] - isa(result, InferenceResult) || return nothing - handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing - spec_types = cases[1].sig - any_covers_full = handled_all_cases = atype <: spec_types + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end function handle_match!( @@ -1313,7 +1263,6 @@ function handle_match!( allow_abstract || isdispatchtuple(spec_types) || return false item = analyze_method!(match, argtypes, flag, state) item === nothing && return false - _any(case->case.sig === spec_types, cases) && return true push!(cases, InliningCase(spec_types, item)) return true end @@ -1445,7 +1394,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo) end - todo + + return todo end function linear_inline_eligible(ir::IRCode) diff --git a/base/sort.jl b/base/sort.jl index d26e9a4b093328..981eea35d96ab8 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -5,7 +5,7 @@ module Sort import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base.Order -using .Base: copymutable, LinearIndices, length, (:), +using .Base: copymutable, LinearIndices, length, (:), iterate, eachindex, axes, first, last, similar, zip, OrdinalRange, AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline, AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !, diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 7619d4e8a03085..275a59b0367301 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -810,6 +810,76 @@ let @test invoke(Any[10]) === false end +# test union-split, non-dispatchtuple callsite inlining + +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Type) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + # issue 43104 @inline isGoodType(@nospecialize x::Type) = @@ -1090,11 +1160,11 @@ end global x44200::Int = 0 function f44200() - global x = 0 - while x < 10 - x += 1 + global x44200 = 0 + while x44200 < 10 + x44200 += 1 end - x + x44200 end let src = code_typed1(f44200) @test count(x -> isa(x, Core.PiNode), src.code) == 0