From 1f21f2d6feacaa5bf3f7ea212c64e98bd66fd265 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Thu, 11 Mar 2021 05:25:29 +0900 Subject: [PATCH] inference: enable constant propagation for union-split signatures (#39305) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The inference precision of certain functions really relies on constant propagation, but currently constant prop' won't happen when a call signature is union split and so sometimes inference ends up with a looser return type: e.g. ```julia julia> Base.return_types((Union{Tuple{Int,Nothing},Tuple{Int,Missing}},)) do t a, b = t a # I expected a::Int, but a::Union{Missing,Nothing,Int} end |> first Union{Missing, Nothing, Int64} ``` This PR: - enables constant prop' for each union-split signature, by calling `abstract_call_method_with_const_args` just after each `abstract_call_method` - refactors `abstract_call_method_with_const_args` into two separate parts, 1.) heuristics to decide whether to do constant prop', 2.) try constant propagation The added test cases should showcase the cases where the inference result could be improved by that. --- I've not seen any notable regression in latency with this PR. Here is a sample benchmark of the impact of this PR on latency, from which I guess this PR is acceptable? > build time: master (caeacef) ```bash Sysimage built. Summary: Total ─────── 61.615938 seconds Base: ─────── 26.575732 seconds 43.1313% Stdlibs: ──── 35.038024 seconds 56.8652% JULIA usr/lib/julia/sys-o.a Generating REPL precompile statements... 30/30 Executing precompile statements... 1378/1378 Precompilation complete. Summary: Total ─────── 116.417013 seconds Generation ── 81.077365 seconds 69.6439% Execution ─── 35.339648 seconds 30.3561% LINK usr/lib/julia/sys.dylib ``` > build time: this PR ```bash Stdlibs total ──── 34.077962 seconds Sysimage built. 
Summary: Total ─────── 61.804573 seconds Base: ─────── 27.724077 seconds 44.8576% Stdlibs: ──── 34.077962 seconds 55.1383% JULIA usr/lib/julia/sys-o.a Generating REPL precompile statements... 30/30 Executing precompile statements... 1362/1362 Precompilation complete. Summary: Total ─────── 111.262672 seconds Generation ── 83.535305 seconds 75.0794% Execution ─── 27.727367 seconds 24.9206% LINK usr/lib/julia/sys.dylib ``` > first time to plot: master (caeacef) ```julia julia> using Plots; @time plot(rand(10,3)) 3.614168 seconds (5.47 M allocations: 324.564 MiB, 5.73% gc time, 53.02% compilation time) ``` > first time to plot: this PR ```julia julia> using Plots; @time plot(rand(10,3)) 3.557919 seconds (5.53 M allocations: 328.812 MiB, 2.89% gc time, 51.94% compilation time) ``` --- - fixes #37610 - some part of this code was taken from #37637 - this PR is originally supposed to be alternative and more generalized version of #39296 --- base/compiler/abstractinterpretation.jl | 468 ++++++++++++------------ base/compiler/ssair/inlining.jl | 46 ++- base/compiler/stmtinfo.jl | 2 +- base/compiler/typeutils.jl | 10 +- test/compiler/inference.jl | 96 +++++ 5 files changed, 379 insertions(+), 243 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index d05ded6af04e2..2fe4932b3ae8c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -36,15 +36,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(Any, false) end valid_worlds = WorldRange() - atype_params = unwrap_unionall(atype).parameters - splitunions = 1 < unionsplitcost(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING + # NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type + splitunions = 1 < unionsplitcost(argtypes) <= InferenceParams(interp).MAX_UNION_SPLITTING mts = Core.MethodTable[] fullmatch = Bool[] if 
splitunions - splitsigs = switchtupleunion(atype) + split_argtypes = switchtupleunion(argtypes) applicable = Any[] + applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match infos = MethodMatchInfo[] - for sig_n in splitsigs + for arg_n in split_argtypes + sig_n = argtypes_to_type(arg_n) mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) if mt === nothing add_remark!(interp, sv, "Could not identify method table for call") @@ -57,7 +59,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(Any, false) end push!(infos, MethodMatchInfo(matches)) - append!(applicable, matches) + for m in matches + push!(applicable, m) + push!(applicable_argtypes, arg_n) + end valid_worlds = intersect(valid_worlds, matches.valid_worlds) thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) found = false @@ -93,16 +98,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), info = MethodMatchInfo(matches) applicable = matches.matches valid_worlds = matches.valid_worlds + applicable_argtypes = nothing end update_valid_age!(sv, valid_worlds) applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom - edgecycle = false edges = MethodInstance[] conditionals = nothing # keeps refinement information of call argument types when the return type is boolean - nonbot = 0 # the index of the only non-Bottom inference result if > 0 seen = 0 # number of signatures actually inferred + any_const_result = false + const_results = Union{InferenceResult,Nothing}[] multiple_matches = napplicable > 1 if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch) @@ -123,46 +129,54 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), rettype = Any break end - sigtuple = unwrap_unionall(sig)::DataType this_rt = Bottom splitunions = false - # TODO: splitunions = 1 < 
unionsplitcost(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING - # this used to trigger a bug in inference recursion detection, and is unmaintained now + # TODO: this used to trigger a bug in inference recursion detection, and is unmaintained now + # sigtuple = unwrap_unionall(sig)::DataType + # splitunions = 1 < unionsplitcost(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING if splitunions splitsigs = switchtupleunion(sig) for sig_n in splitsigs - rt, edgecycle1, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) - edgecycle |= edgecycle1::Bool + rt, edgecycle, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) if edge !== nothing push!(edges, edge) end + this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] + const_rt, const_result = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle) + if const_rt !== rt && const_rt ⊑ rt + rt = const_rt + end + push!(const_results, const_result) + if const_result !== nothing + any_const_result = true + end this_rt = tmerge(this_rt, rt) if bail_out_call(interp, this_rt, sv) break end end else - this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) - edgecycle |= edgecycle1::Bool + this_rt, edgecycle, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) if edge !== nothing push!(edges, edge) end + # try constant propagation with argtypes for this match + # this is in preparation for inlining, or improving the return result + this_argtypes = applicable_argtypes === nothing ? 
argtypes : applicable_argtypes[i] + const_this_rt, const_result = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle) + if const_this_rt !== this_rt && const_this_rt ⊑ this_rt + this_rt = const_this_rt + end + push!(const_results, const_result) + if const_result !== nothing + any_const_result = true + end end this_conditional = ignorelimited(this_rt) this_rt = widenwrappedconditional(this_rt) @assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context" - if this_rt !== Bottom - if nonbot === 0 - nonbot = i - else - nonbot = -1 - end - end seen += 1 rettype = tmerge(rettype, this_rt) - if bail_out_call(interp, rettype, sv) - break - end if this_conditional !== Bottom && is_lattice_bool(rettype) && fargs !== nothing if conditionals === nothing conditionals = Any[Bottom for _ in 1:length(argtypes)], @@ -183,54 +197,18 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), conditionals[2][i] = tmerge(conditionals[2][i], elsetype) end end - end - # try constant propagation if only 1 method is inferred to non-Bottom - # this is in preparation for inlining, or improving the return result - is_unused = call_result_unused(sv) - if nonbot > 0 && seen == napplicable && (!edgecycle || !is_unused) && - is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation - # if there's a possibility we could constant-propagate a better result - # (hopefully without doing too much work), try to do that now - # TODO: refactor this, enable constant propagation for each (union-split) signature - match = applicable[nonbot]::MethodMatch - const_rettype, result = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle) - const_conditional = ignorelimited(const_rettype) - @assert !(const_conditional isa Conditional) "invalid lattice element returned from inter-procedural context" - const_rettype = 
widenwrappedconditional(const_rettype) - if ignorelimited(const_rettype) ⊑ rettype - # use the better result, if it is a refinement of rettype - rettype = const_rettype - if const_conditional isa InterConditional && conditionals === nothing && fargs !== nothing - arg = fargs[const_conditional.slot] - if arg isa Slot - rettype = Conditional(arg, const_conditional.vtype, const_conditional.elsetype) - if const_rettype isa LimitedAccuracy - rettype = LimitedAccuracy(rettype, const_rettype.causes) - end - end - end - end - if result !== nothing - info = ConstCallInfo(info, result) - end - # and update refinements with the InterConditional info too - # (here we ignorelimited, since there isn't much below this in the - # lattice, particularly when we're already using tmeet) - if const_conditional isa InterConditional && conditionals !== nothing - let i = const_conditional.slot, - vtype = const_conditional.vtype, - elsetype = const_conditional.elsetype - if !(vtype ⊑ conditionals[1][i]) - vtype = tmeet(conditionals[1][i], widenconst(vtype)) - end - if !(elsetype ⊑ conditionals[2][i]) - elsetype = tmeet(conditionals[2][i], widenconst(elsetype)) - end - conditionals[1][i] = vtype - conditionals[2][i] = elsetype - end + if bail_out_call(interp, rettype, sv) + break end end + + # inliner uses this information only when there is a single match that has been improved + # by constant analysis, but let's create `ConstCallInfo` if there has been any successful + # constant propagation happened since other consumers may be interested in this + if any_const_result && seen == napplicable + info = ConstCallInfo(info, const_results) + end + if rettype isa LimitedAccuracy union!(sv.pclimitations, rettype.causes) rettype = rettype.typ @@ -284,7 +262,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end end @assert !(rettype isa InterConditional) "invalid lattice element returned from inter-procedural context" - if is_unused && !(rettype === Bottom) + + 
if call_result_unused(sv) && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") # We're mainly only here because the optimizer might want this code, # but we ourselves locally don't typically care about it locally @@ -330,160 +309,6 @@ function add_call_backedges!(interp::AbstractInterpreter, end end -function const_prop_profitable(@nospecialize(arg)) - # have new information from argtypes that wasn't available from the signature - if isa(arg, PartialStruct) - for b in arg.fields - isconstType(b) && return true - const_prop_profitable(b) && return true - end - end - isa(arg, PartialOpaque) && return true - isa(arg, Const) || return true - val = arg.val - # don't consider mutable values or Strings useful constants - return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) -end - -# This is a heuristic to avoid trying to const prop through complicated functions -# where we would spend a lot of time, but are probably unliekly to get an improved -# result anyway. -function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) - if method.is_for_opaque_closure - # Not inlining an opaque closure can be very expensive, so be generous - # with the const-prop-ability. It is quite possible that we can't infer - # anything at all without const-propping, so the inlining check below - # isn't particularly helpful here. - return true - end - # Peek at the inferred result for the function to determine if the optimizer - # was able to cut it down to something simple (inlineable in particular). - # If so, there's a good chance we might be able to const prop all the way - # through and learn something new. 
- code = get(code_cache(interp), mi, nothing) - declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) - cache_inlineable = declared_inline - if isdefined(code, :inferred) && !cache_inlineable - cache_inf = code.inferred - if !(cache_inf === nothing) - cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing - end - end - if !cache_inlineable - return false - end - return true -end - -function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool) - method = match.method - nargs::Int = method.nargs - method.isva && (nargs -= 1) - if length(argtypes) < nargs - return Any, nothing - end - haveconst = false - allconst = true - # see if any or all of the arguments are constant and propagating constants may be worthwhile - for a in argtypes - a = widenconditional(a) - if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque) - allconst = false - end - if !haveconst && has_nontrivial_const_info(a) && const_prop_profitable(a) - haveconst = true - end - if haveconst && !allconst - break - end - end - haveconst || improvable_via_constant_propagation(rettype) || return Any, nothing - force_inference = method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation - if !force_inference && nargs > 1 - if istopfunction(f, :getindex) || istopfunction(f, :setindex!) 
- arrty = argtypes[2] - # don't propagate constant index into indexing of non-constant array - if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) - return Any, nothing - elseif arrty ⊑ Array - return Any, nothing - end - elseif istopfunction(f, :iterate) - itrty = argtypes[2] - if itrty ⊑ Array - return Any, nothing - end - end - end - if !force_inference && !allconst && - (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || - istopfunction(f, :(==)) || istopfunction(f, :!=) || - istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || - istopfunction(f, :<<) || istopfunction(f, :>>)) - # it is almost useless to inline the op of when all the same type, - # but highly worthwhile to inline promote of a constant - length(argtypes) > 2 || return Any, nothing - t1 = widenconst(argtypes[2]) - all_same = true - for i in 3:length(argtypes) - if widenconst(argtypes[i]) !== t1 - all_same = false - break - end - end - all_same && return Any, nothing - end - if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) - force_inference = true - end - force_inference |= allconst - mi = specialize_method(match, !force_inference) - if mi === nothing - add_remark!(interp, sv, "[constprop] Failed to specialize") - return Any, nothing - end - mi = mi::MethodInstance - # decide if it's likely to be worthwhile - if !force_inference && !const_prop_heuristic(interp, method, mi) - add_remark!(interp, sv, "[constprop] Disabled by heuristic") - return Any, nothing - end - inf_cache = get_inference_cache(interp) - inf_result = cache_lookup(mi, argtypes, inf_cache) - if inf_result === nothing - if edgecycle - # if there might be a cycle, check to make sure we don't end up - # calling ourselves here. 
- infstate = sv - cyclei = 0 - while !(infstate === nothing) - if method === infstate.linfo.def && any(infstate.result.overridden_by_const) - add_remark!(interp, sv, "[constprop] Edge cycle encountered") - return Any, nothing - end - if cyclei < length(infstate.callers_in_cycle) - cyclei += 1 - infstate = infstate.callers_in_cycle[cyclei] - else - cyclei = 0 - infstate = infstate.parent - end - end - end - inf_result = InferenceResult(mi, argtypes) - frame = InferenceState(inf_result, #=cache=#false, interp) - frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it - frame.parent = sv - push!(inf_cache, inf_result) - typeinf(interp, frame) || return Any, nothing - end - result = inf_result.result - # if constant inference hits a cycle, just bail out - isa(result, InferenceState) && return Any, nothing - add_backedge!(inf_result.linfo, sv) - return result, inf_result -end - const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result." function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) @@ -643,6 +468,197 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp return rt, edgecycle, edge end +function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + mi = maybe_get_const_prop_profitable(interp, rettype, f, argtypes, match, sv, edgecycle) + mi === nothing && return Any, nothing + # try constant prop' + inf_cache = get_inference_cache(interp) + inf_result = cache_lookup(mi, argtypes, inf_cache) + if inf_result === nothing + if edgecycle + # if there might be a cycle, check to make sure we don't end up + # calling ourselves here. 
+ infstate = sv + cyclei = 0 + while !(infstate === nothing) + if match.method === infstate.linfo.def && any(infstate.result.overridden_by_const) + add_remark!(interp, sv, "[constprop] Edge cycle encountered") + return Any, nothing + end + if cyclei < length(infstate.callers_in_cycle) + cyclei += 1 + infstate = infstate.callers_in_cycle[cyclei] + else + cyclei = 0 + infstate = infstate.parent + end + end + end + inf_result = InferenceResult(mi, argtypes) + frame = InferenceState(inf_result, #=cache=#false, interp) + frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it + frame.parent = sv + push!(inf_cache, inf_result) + typeinf(interp, frame) || return Any, nothing + end + result = inf_result.result + # if constant inference hits a cycle, just bail out + isa(result, InferenceState) && return Any, nothing + add_backedge!(mi, sv) + return result, inf_result +end + +# if there's a possibility we could get a better result (hopefully without doing too much work) +# returns `MethodInstance` with constant arguments, returns nothing otherwise +function maybe_get_const_prop_profitable(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + const_prop_entry_heuristic(interp, rettype, sv, edgecycle) || return nothing + method = match.method + nargs::Int = method.nargs + method.isva && (nargs -= 1) + if length(argtypes) < nargs + return nothing + end + const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, rettype) || return nothing + allconst = is_allconst(argtypes) + force = force_const_prop(interp, f, method) + force || const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) || return nothing + force |= allconst + mi = specialize_method(match, !force) + if mi === nothing + add_remark!(interp, sv, "[constprop] Failed to specialize") + return nothing + end + mi = 
mi::MethodInstance + if !force && !const_prop_methodinstance_heuristic(interp, method, mi) + add_remark!(interp, sv, "[constprop] Disabled by heuristic") + return nothing + end + return mi +end + +function const_prop_entry_heuristic(interp::AbstractInterpreter, @nospecialize(rettype), sv::InferenceState, edgecycle::Bool) + call_result_unused(sv) && edgecycle && return false + return is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation +end + +# see if propagating constants may be worthwhile +function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) + return true + end + end + return false +end + +function is_const_prop_profitable_arg(@nospecialize(arg)) + # have new information from argtypes that wasn't available from the signature + if isa(arg, PartialStruct) + for b in arg.fields + isconstType(b) && return true + is_const_prop_profitable_arg(b) && return true + end + end + isa(arg, PartialOpaque) && return true + isa(arg, Const) || return true + val = arg.val + # don't consider mutable values or Strings useful constants + return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) +end + +function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype)) + return improvable_via_constant_propagation(rettype) +end + +function is_allconst(argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque) + return false + end + end + return true +end + +function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method::Method) + return method.aggressive_constprop || + InferenceParams(interp).aggressive_constant_propagation || + istopfunction(f, :getproperty) || + istopfunction(f, :setproperty!) 
+end + +function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool) + if nargs > 1 + if istopfunction(f, :getindex) || istopfunction(f, :setindex!) + arrty = argtypes[2] + # don't propagate constant index into indexing of non-constant array + if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) + return false + elseif arrty ⊑ Array + return false + end + elseif istopfunction(f, :iterate) + itrty = argtypes[2] + if itrty ⊑ Array + return false + end + end + end + if !allconst && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || + istopfunction(f, :(==)) || istopfunction(f, :!=) || + istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || + istopfunction(f, :<<) || istopfunction(f, :>>)) + # it is almost useless to inline the op when all the same type, + # but highly worthwhile to inline promote of a constant + length(argtypes) > 2 || return false + t1 = widenconst(argtypes[2]) + all_same = true + for i in 3:length(argtypes) + if widenconst(argtypes[i]) !== t1 + all_same = false + break + end + end + return !all_same + end + return true +end + +# This is a heuristic to avoid trying to const prop through complicated functions +# where we would spend a lot of time, but are probably unlikely to get an improved +# result anyway. +function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) + if method.is_for_opaque_closure + # Not inlining an opaque closure can be very expensive, so be generous + # with the const-prop-ability. It is quite possible that we can't infer + # anything at all without const-propping, so the inlining check below + # isn't particularly helpful here. + return true + end + # Peek at the inferred result for the function to determine if the optimizer + # was able to cut it down to something simple (inlineable in particular). 
+ # If so, there's a good chance we might be able to const prop all the way + # through and learn something new. + code = get(code_cache(interp), mi, nothing) + declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) + cache_inlineable = declared_inline + if isdefined(code, :inferred) && !cache_inlineable + cache_inf = code.inferred + if !(cache_inf === nothing) + cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing + end + end + if !cache_inlineable + return false + end + return true +end + # This is only for use with `Conditional`. # In general, usage of this is wrong. function ssa_def_slot(@nospecialize(arg), sv::InferenceState) @@ -1224,7 +1240,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part rt = const_rettype end if result !== nothing - info = ConstCallInfo(info, result) + info = ConstCallInfo(info, Union{Nothing,InferenceResult}[result]) end end return CallMeta(rt, info) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 34b966a5ce9a5..b75172656dd2f 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -634,12 +634,19 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: new_stmt = Expr(:call, argexprs[2], def, state...) state1 = insert_node!(ir, idx, call.rt, new_stmt) new_sig = with_atype(call_sig(ir, new_stmt)::Signature) - if isa(call.info, ConstCallInfo) - handle_const_call!(ir, state1.id, new_stmt, call.info, new_sig, + info = call.info + handled = false + if isa(info, ConstCallInfo) + if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig, call.rt, istate, false, todo) - elseif isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo) - info = isa(call.info, MethodMatchInfo) ? 
- MethodMatchInfo[call.info] : call.info.matches + handled = true + else + info = info.call + end + end + if !handled && (isa(info, MethodMatchInfo) || isa(info, UnionSplitInfo)) + info = isa(info, MethodMatchInfo) ? + MethodMatchInfo[info] : info.matches # See if we can inline this call to `iterate` analyze_single_call!(ir, todo, state1.id, new_stmt, new_sig, call.rt, info, istate) @@ -736,7 +743,8 @@ function resolve_todo(todo::InliningTodo, state::InliningState) src = copy(src) end - state.et !== nothing && push!(state.et, todo.mi) + et = state.et + et !== nothing && push!(et, todo.mi) return InliningTodo(todo.mi, src) end @@ -1164,25 +1172,33 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int return nothing end -function handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, +function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature, @nospecialize(calltype), state::InliningState, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) - item = InliningTodo(info.result, sig.atypes, calltype) - validate_sparams(item.mi.sparam_vals) || return + # when multiple matches are found, bail out and later inliner will union-split this signature + # TODO effectively use multiple constant analysis results here + length(info.results) == 1 || return false + result = info.results[1] + isa(result, InferenceResult) || return false + + item = InliningTodo(result, sig.atypes, calltype) + validate_sparams(item.mi.sparam_vals) || return true mthd_sig = item.mi.def.sig mistypes = item.mi.specTypes state.mi_cache !== nothing && (item = resolve_todo(item, state)) if sig.atype <: mthd_sig - return handle_single_case!(ir, stmt, idx, item, isinvoke, todo) + handle_single_case!(ir, stmt, idx, item, isinvoke, todo) + return true else - item === nothing && return + item === nothing && return true # Union split out the error case item = UnionSplit(false, sig.atype, Pair{Any, Any}[mistypes => item]) if isinvoke stmt.args = 
rewrite_invoke_exprargs!(stmt.args) end push!(todo, idx=>item) + return true end end @@ -1216,9 +1232,11 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # it'll have performed a specialized analysis for just this case. Use its # result. if isa(info, ConstCallInfo) - handle_const_call!(ir, idx, stmt, info, sig, calltype, state, - sig.f === Core.invoke, todo) - continue + if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo) + continue + else + info = info.call + end end # Handle invoke diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 73442001884e3..7553b9a4394b6 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -97,7 +97,7 @@ constant information. """ struct ConstCallInfo call::Any - result::InferenceResult + results::Vector{Union{Nothing,InferenceResult}} end """ diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index f290608f5b8ad..fc9282394b91a 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -184,10 +184,16 @@ function switchtupleunion(@nospecialize(ty)) return _switchtupleunion(Any[tparams...], length(tparams), [], ty) end +switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing) + function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) if i == 0 - tpl = rewrap_unionall(Tuple{t...}, origt) - push!(tunion, tpl) + if origt === nothing + push!(tunion, copy(t)) + else + tpl = rewrap_unionall(Tuple{t...}, origt) + push!(tunion, tpl) + end else ti = t[i] if isa(ti, Union) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 538d226e885f5..1a5ac13684c29 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3106,3 +3106,99 @@ end == [Int] let f() = Val(fieldnames(Complex{Int})) @test @inferred(f()) === Val((:re,:im)) end + +@testset "switchtupleunion" begin + # signature tuple + let + 
tunion = Core.Compiler.switchtupleunion(Tuple{Union{Int32,Int64}, Nothing}) + @test Tuple{Int32, Nothing} in tunion + @test Tuple{Int64, Nothing} in tunion + end + let + tunion = Core.Compiler.switchtupleunion(Tuple{Union{Int32,Int64}, Union{Float32,Float64}, Nothing}) + @test Tuple{Int32, Float32, Nothing} in tunion + @test Tuple{Int32, Float64, Nothing} in tunion + @test Tuple{Int64, Float32, Nothing} in tunion + @test Tuple{Int64, Float64, Nothing} in tunion + end + + # argtypes + let + tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)]) + @test length(tunion) == 2 + @test Any[Int32, Core.Const(nothing)] in tunion + @test Any[Int64, Core.Const(nothing)] in tunion + end + let + tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)]) + @test length(tunion) == 4 + @test Any[Int32, Float32, Core.Const(nothing)] in tunion + @test Any[Int32, Float64, Core.Const(nothing)] in tunion + @test Any[Int64, Float32, Core.Const(nothing)] in tunion + @test Any[Int64, Float64, Core.Const(nothing)] in tunion + end +end + +@testset "constant prop' for union split signature" begin + anonymous_module() = Core.eval(@__MODULE__, :(module $(gensym()) end))::Module + + # indexing into tuples really relies on constant prop', and we will get looser result + # (`Union{Int,String,Char}`) if constant prop' doesn't happen for splitunion signatures + tt = (Union{Tuple{Int,String},Tuple{Int,Char}},) + @test Base.return_types(tt) do t + getindex(t, 1) + end == Any[Int] + @test Base.return_types(tt) do t + getindex(t, 2) + end == Any[Union{String,Char}] + @test Base.return_types(tt) do t + a, b = t + a + end == Any[Int] + @test Base.return_types(tt) do t + a, b = t + b + end == Any[Union{String,Char}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v::Int + end + struct F64 + val::Float64 + _v::Int + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == 
Any[Union{Float32,Float64}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v + end + struct F64 + val::Float64 + _v + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == Any[Union{Float32,Float64}] + + @test Base.return_types((Union{Tuple{Nothing,Any,Any},Tuple{Nothing,Any}},)) do t + getindex(t, 1) + end == Any[Nothing] + + # issue #37610 + @test Base.return_types((typeof(("foo" => "bar", "baz" => nothing)), Int)) do a, i + y = iterate(a, i) + if y !== nothing + (k, v), st = y + return k, v + end + return y + end == Any[Union{Nothing, Tuple{String, Union{Nothing, String}}}] +end