From 3ebbe2c25b35811c40f1c97d85afe8d8f3870d9c Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 21 Jan 2021 01:54:23 +0900 Subject: [PATCH] inference: enable constant propagation for union-split signatures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The inference precision of certain functions really relies on constant propagation, but currently constant prop' won't happen when a call signature is union split and so sometimes inference ends up looser return type: e.g. ```julia julia> Base.return_types((Union{Tuple{Int,Nothing},Tuple{Int,Missing}},)) do t a, b = t a # I expected a::Int, but a::Union{Missing,Nothing,Int} end |> first Union{Missing, Nothing, Int64} ``` This PR: - enables constant prop' for each union signatures, by calling `abstract_call_method_with_const_args` just after each `abstract_call_method` - refactor `abstract_call_method_with_const_args` into two separate parts, 1.) heuristics to decide whether to do constant prop', 2.) try constant propagation The added test cases will should showcase the cases where the inference result could be improved by that. --- I've not seen notable regression in latency with this PR. Here is a sample benchmark of the impact of this PR on latency, from which I guess this PR is acceptable ? > build time: master (caeacef) ```bash Sysimage built. Summary: Total ─────── 61.615938 seconds Base: ─────── 26.575732 seconds 43.1313% Stdlibs: ──── 35.038024 seconds 56.8652% JULIA usr/lib/julia/sys-o.a Generating REPL precompile statements... 30/30 Executing precompile statements... 1378/1378 Precompilation complete. Summary: Total ─────── 116.417013 seconds Generation ── 81.077365 seconds 69.6439% Execution ─── 35.339648 seconds 30.3561% LINK usr/lib/julia/sys.dylib ``` > build time: this PR ```bash Stdlibs total ──── 34.077962 seconds Sysimage built. Summary: Total ─────── 61.804573 seconds Base: ─────── 27.724077 seconds 44.8576% Stdlibs: ──── 34.077962 seconds 55.1383% JULIA usr/lib/julia/sys-o.a Generating REPL precompile statements... 30/30 Executing precompile statements... 1362/1362 Precompilation complete. Summary: Total ─────── 111.262672 seconds Generation ── 83.535305 seconds 75.0794% Execution ─── 27.727367 seconds 24.9206% LINK usr/lib/julia/sys.dylib ``` > first time to plot: master (caeacef) ```julia julia> using Plots; @time plot(rand(10,3)) 3.614168 seconds (5.47 M allocations: 324.564 MiB, 5.73% gc time, 53.02% compilation time) ``` > first time to plot: this PR ```julia julia> using Plots; @time plot(rand(10,3)) 3.557919 seconds (5.53 M allocations: 328.812 MiB, 2.89% gc time, 51.94% compilation time) ``` --- - fixes #37610 - some part of this code was taken from #37637 - this PR is originally supposed to be alternative and more generalized version of #39296 --- base/compiler/abstractinterpretation.jl | 374 +++++++++++++----------- base/compiler/typeutils.jl | 10 +- test/compiler/inference.jl | 64 ++++ 3 files changed, 276 insertions(+), 172 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5d2ba91e5d9e5c..909f2b148e9069 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -40,9 +40,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), fullmatch = Bool[] if splitunions splitsigs = switchtupleunion(atype) + split_argtypes = switchtupleunion(argtypes) applicable = Any[] + # arrays like `argtypes`, including constants, for each match + applicable_argtypes = Vector{Any}[] infos = MethodMatchInfo[] - for sig_n in splitsigs + for j in 1:length(splitsigs) + sig_n = splitsigs[j] mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) if mt === nothing add_remark!(interp, sv, "Could not identify method table for call") @@ -56,6 +60,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end push!(infos, MethodMatchInfo(matches)) append!(applicable, matches) + for _ in 1:length(matches) + push!(applicable_argtypes, split_argtypes[j]) + end + # @assert argtypes_to_type(split_argtypes[j]) === sig_n "invalid union split" valid_worlds = intersect(valid_worlds, matches.valid_worlds) thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) found = false @@ -91,15 +99,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), info = MethodMatchInfo(matches) applicable = matches.matches valid_worlds = matches.valid_worlds + applicable_argtypes = nothing end update_valid_age!(sv, valid_worlds) applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom edgecycle = false - edges = Any[] - nonbot = 0 # the index of the only non-Bottom inference result if > 0 - seen = 0 # number of signatures actually inferred + edges = MethodInstance[] istoplevel = sv.linfo.def isa Module multiple_matches = napplicable > 1 @@ -122,8 +129,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), break end sigtuple = unwrap_unionall(sig)::DataType - splitunions = false this_rt = Bottom + splitunions = false # TODO: splitunions = 1 < unionsplitcost(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING # currently this triggers a bug in inference recursion detection if splitunions @@ -134,8 +141,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), push!(edges, edge) end edgecycle |= edgecycle1::Bool + this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] + const_rt = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle) + if const_rt !== rt && const_rt ⊑ rt + rt = const_rt + end this_rt = tmerge(this_rt, rt) - this_rt === Any && break + if this_rt === Any + break + end end else this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) @@ -143,33 +157,19 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if edge !== nothing push!(edges, edge) end - end - if this_rt !== Bottom - if nonbot === 0 - nonbot = i - else - nonbot = -1 + this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] + const_this_rt = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle) + if const_this_rt !== this_rt && const_this_rt ⊑ this_rt + this_rt = const_this_rt end end - seen += 1 rettype = tmerge(rettype, this_rt) - rettype === Any && break - end - # try constant propagation if only 1 method is inferred to non-Bottom - # this is in preparation for inlining, or improving the return result - is_unused = call_result_unused(sv) - if nonbot > 0 && seen == napplicable && (!edgecycle || !is_unused) && - is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation - # if there's a possibility we could constant-propagate a better result - # (hopefully without doing too much work), try to do that now - # TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge - const_rettype = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle) - if const_rettype ⊑ rettype - # use the better result, if it's a refinement of rettype - rettype = const_rettype + if rettype === Any + break end end - if is_unused && !(rettype === Bottom) + + if call_result_unused(sv) && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") # We're mainly only here because the optimizer might want this code, # but we ourselves locally don't typically care about it locally @@ -205,148 +205,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(rettype, info) end - -function const_prop_profitable(@nospecialize(arg)) - # have new information from argtypes that wasn't available from the signature - if isa(arg, PartialStruct) - for b in arg.fields - isconstType(b) && return true - const_prop_profitable(b) && return true - end - end - isa(arg, Const) || return true - val = arg.val - # don't consider mutable values or Strings useful constants - return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) -end - -# This is a heuristic to avoid trying to const prop through complicated functions -# where we would spend a lot of time, but are probably unliekly to get an improved -# result anyway. -function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) - # Peek at the inferred result for the function to determine if the optimizer - # was able to cut it down to something simple (inlineable in particular). - # If so, there's a good chance we might be able to const prop all the way - # through and learn something new. - code = get(code_cache(interp), mi, nothing) - declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) - cache_inlineable = declared_inline - if isdefined(code, :inferred) && !cache_inlineable - cache_inf = code.inferred - if !(cache_inf === nothing) - cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf) - cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf) - cache_inlineable = cache_src_inferred && cache_src_inlineable - end - end - if !cache_inlineable - return false - end - return true -end - -function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool) - method = match.method - nargs::Int = method.nargs - method.isva && (nargs -= 1) - length(argtypes) >= nargs || return Any - haveconst = false - allconst = true - # see if any or all of the arguments are constant and propagating constants may be worthwhile - for a in argtypes - a = widenconditional(a) - if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) - allconst = false - end - if !haveconst && has_nontrivial_const_info(a) && const_prop_profitable(a) - haveconst = true - end - if haveconst && !allconst - break - end - end - haveconst || improvable_via_constant_propagation(rettype) || return Any - force_inference = method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation - if !force_inference && nargs > 1 - if istopfunction(f, :getindex) || istopfunction(f, :setindex!) - arrty = argtypes[2] - # don't propagate constant index into indexing of non-constant array - if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) - return Any - elseif arrty ⊑ Array - return Any - end - elseif istopfunction(f, :iterate) - itrty = argtypes[2] - if itrty ⊑ Array - return Any - end - end - end - if !force_inference && !allconst && - (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || - istopfunction(f, :(==)) || istopfunction(f, :!=) || - istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || - istopfunction(f, :<<) || istopfunction(f, :>>)) - # it is almost useless to inline the op of when all the same type, - # but highly worthwhile to inline promote of a constant - length(argtypes) > 2 || return Any - t1 = widenconst(argtypes[2]) - all_same = true - for i in 3:length(argtypes) - if widenconst(argtypes[i]) !== t1 - all_same = false - break - end - end - all_same && return Any - end - if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) - force_inference = true - end - force_inference |= allconst - mi = specialize_method(match, !force_inference) - mi === nothing && return Any - mi = mi::MethodInstance - # decide if it's likely to be worthwhile - if !force_inference && !const_prop_heuristic(interp, method, mi) - return Any - end - inf_cache = get_inference_cache(interp) - inf_result = cache_lookup(mi, argtypes, inf_cache) - if inf_result === nothing - if edgecycle - # if there might be a cycle, check to make sure we don't end up - # calling ourselves here. - infstate = sv - cyclei = 0 - while !(infstate === nothing) - if method === infstate.linfo.def && any(infstate.result.overridden_by_const) - return Any - end - if cyclei < length(infstate.callers_in_cycle) - cyclei += 1 - infstate = infstate.callers_in_cycle[cyclei] - else - cyclei = 0 - infstate = infstate.parent - end - end - end - inf_result = InferenceResult(mi, argtypes) - frame = InferenceState(inf_result, #=cache=#false, interp) - frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it - frame.parent = sv - push!(inf_cache, inf_result) - typeinf(interp, frame) || return Any - end - result = inf_result.result - # if constant inference hits a cycle, just bail out - isa(result, InferenceState) && return Any - add_backedge!(inf_result.linfo, sv) - return result -end - const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result." function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) @@ -506,6 +364,182 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp return rt, edgecycle, edge end +function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + mi = maybe_get_const_prop_profitable(interp, rettype, f, argtypes, match, sv, edgecycle) + mi === nothing && return Any + # try constant prop' + inf_cache = get_inference_cache(interp) + inf_result = cache_lookup(mi, argtypes, inf_cache) + if inf_result === nothing + if edgecycle + # if there might be a cycle, check to make sure we don't end up + # calling ourselves here. + infstate = sv + cyclei = 0 + while !(infstate === nothing) + if match.method === infstate.linfo.def && any(infstate.result.overridden_by_const) + return Any + end + if cyclei < length(infstate.callers_in_cycle) + cyclei += 1 + infstate = infstate.callers_in_cycle[cyclei] + else + cyclei = 0 + infstate = infstate.parent + end + end + end + inf_result = InferenceResult(mi, argtypes) + frame = InferenceState(inf_result, #=cache=#false, interp) + frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it + frame.parent = sv + push!(inf_cache, inf_result) + typeinf(interp, frame) || return Any + end + result = inf_result.result + # if constant inference hits a cycle, just bail out + isa(result, InferenceState) && return Any + add_backedge!(mi, sv) + return result +end + +# if there's a possibility we could get a better result (hopefully without doing too much work) +# returns `MethodInstance` with constant arguments, returns nothing otherwise +function maybe_get_const_prop_profitable(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + const_prop_entry_heuristic(interp, rettype, sv, edgecycle) || return nothing + method = match.method + nargs::Int = method.nargs + method.isva && (nargs -= 1) + length(argtypes) >= nargs || return nothing + const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, rettype) || return nothing + allconst = is_allconst(argtypes) + force = force_const_prop(interp, f, method) + force || const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) || return nothing + force |= allconst + mi = specialize_method(match, !force) + mi === nothing && return nothing + mi = mi::MethodInstance + force || const_prop_methodinstance_heuristic(interp, method, mi) || return nothing + return mi +end + +function const_prop_entry_heuristic(interp::AbstractInterpreter, @nospecialize(rettype), sv::InferenceState, edgecycle::Bool) + call_result_unused(sv) && edgecycle && return false + return is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation +end + +# see if propagating constants may be worthwhile +function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) + return true + end + end + return false +end + +function is_const_prop_profitable_arg(@nospecialize(arg)) + # have new information from argtypes that wasn't available from the signature + if isa(arg, PartialStruct) + for b in arg.fields + isconstType(b) && return true + is_const_prop_profitable_arg(b) && return true + end + end + isa(arg, Const) || return true + val = arg.val + # don't consider mutable values or Strings useful constants + return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) +end + +function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype)) + return improvable_via_constant_propagation(rettype) +end + +function is_allconst(argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) + return false + end + end + return true +end + +function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method::Method) + return method.aggressive_constprop || + InferenceParams(interp).aggressive_constant_propagation || + istopfunction(f, :getproperty) || + istopfunction(f, :setproperty!) +end + +function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool) + if nargs > 1 + if istopfunction(f, :getindex) || istopfunction(f, :setindex!) + arrty = argtypes[2] + # don't propagate constant index into indexing of non-constant array + if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) + return false + elseif arrty ⊑ Array + return false + end + elseif istopfunction(f, :iterate) + itrty = argtypes[2] + if itrty ⊑ Array + return false + end + end + end + if !allconst && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || + istopfunction(f, :(==)) || istopfunction(f, :!=) || + istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || + istopfunction(f, :<<) || istopfunction(f, :>>)) + # it is almost useless to inline the op of when all the same type, + # but highly worthwhile to inline promote of a constant + length(argtypes) > 2 || return false + t1 = widenconst(argtypes[2]) + all_same = true + for i in 3:length(argtypes) + if widenconst(argtypes[i]) !== t1 + all_same = false + break + end + end + return !all_same + end + return true +end + +# This is a heuristic to avoid trying to const prop through complicated functions +# where we would spend a lot of time, but are probably unliekly to get an improved +# result anyway. +function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) + # Peek at the inferred result for the function to determine if the optimizer + # was able to cut it down to something simple (inlineable in particular). + # If so, there's a good chance we might be able to const prop all the way + # through and learn something new. + code = get(code_cache(interp), mi, nothing) + declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) + cache_inlineable = declared_inline + if isdefined(code, :inferred) && !cache_inlineable + cache_inf = code.inferred + if !(cache_inf === nothing) + cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf) + cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf) + cache_inlineable = cache_src_inferred && cache_src_inlineable + end + end + if !cache_inlineable + return false + end + return true +end + # This is only for use with `Conditional`. # In general, usage of this is wrong. function ssa_def_slot(@nospecialize(arg), sv::InferenceState) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index 10f4a8b949da56..c86ec36dfb319b 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -181,10 +181,16 @@ function switchtupleunion(@nospecialize(ty)) return _switchtupleunion(Any[tparams...], length(tparams), [], ty) end +switchtupleunion(t::Vector{Any}) = _switchtupleunion(t, length(t), [], nothing) + function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) if i == 0 - tpl = rewrap_unionall(Tuple{t...}, origt) - push!(tunion, tpl) + if origt === nothing + push!(tunion, copy(t)) + else + tpl = rewrap_unionall(Tuple{t...}, origt) + push!(tunion, tpl) + end else ti = t[i] if isa(ti, Union) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index cd005d7f39d435..3f4e9c8f839087 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3020,3 +3020,67 @@ end # Bare Core.Argument in IR @eval f_bare_argument(x) = $(Core.Argument(2)) @test Base.return_types(f_bare_argument, (Int,))[1] == Int + +@testset "constant prop' for union split signature" begin + anonymous_module() = Core.eval(@__MODULE__, :(module $(gensym()) end))::Module + + # indexing into tuples really relies on constant prop', and we will get looser result + # (`Union{Int,String,Char}`) if constant prop' doesn't happen for splitunion signatures + tt = (Union{Tuple{Int,String},Tuple{Int,Char}},) + @test Base.return_types(tt) do t + getindex(t, 1) + end == Any[Int] + @test Base.return_types(tt) do t + getindex(t, 2) + end == Any[Union{String,Char}] + @test Base.return_types(tt) do t + a, b = t + a + end == Any[Int] + @test Base.return_types(tt) do t + a, b = t + b + end == Any[Union{String,Char}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v::Int + end + struct F64 + val::Float64 + _v::Int + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == Any[Union{Float32,Float64}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v + end + struct F64 + val::Float64 + _v + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == Any[Union{Float32,Float64}] + + @test Base.return_types((Union{Tuple{Nothing,Any,Any},Tuple{Nothing,Any}},)) do t + getindex(t, 1) + end == Any[Nothing] + + # issue #37610 + @test Base.return_types((typeof(("foo" => "bar", "baz" => nothing)), Int)) do a, i + y = iterate(a, i) + if y !== nothing + (k, v), st = y + return k, v + end + return y + end == Any[Union{Nothing, Tuple{String, Union{Nothing, String}}}] +end