From b8ea05cf8a87476ccf3716083d3d08f2e094f83a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 6 Jan 2022 01:18:39 +0900 Subject: [PATCH] optimizer: fully support inlining of union-split, partially constant-prop' callsite (#43347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287: ```julia ulia> using BenchmarkTools julia> X = rand(ComplexF32, 64, 64); julia> dst = reinterpret(reshape, Float32, X); julia> src = copy(dst); julia> @btime copyto!($dst, $src); 50.819 μs (1 allocation: 32 bytes) # v1.6.4 41.081 μs (0 allocations: 0 bytes) # this commit ``` fixes #43287 --- base/compiler/abstractinterpretation.jl | 1 + base/compiler/ssair/inlining.jl | 190 ++++++++++++------------ base/compiler/stmtinfo.jl | 33 ++-- test/compiler/inline.jl | 24 +-- 4 files changed, 133 insertions(+), 115 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 966e0ee54be5e3..dc7b4489cb0f5c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -156,6 +156,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # by constant analysis, but let's create `ConstCallInfo` if there has been any successful # constant propagation happened since other consumers may be interested in this if any_const_result && seen == napplicable + @assert napplicable == nmatches(info) == length(const_results) info = ConstCallInfo(info, const_results) end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index d454eeb1593b39..00c853ec207b2c 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -675,24 +675,17 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: new_stmt = Expr(:call, argexprs[2], def, state...) state1 = insert_node!(ir, idx, NewInstruction(new_stmt, call.rt)) new_sig = with_atype(call_sig(ir, new_stmt)::Signature) - info = call.info - handled = false - if isa(info, ConstCallInfo) - if maybe_handle_const_call!( - ir, state1.id, new_stmt, info, new_sig, - istate, false, todo) - handled = true - else - info = info.call - end - end - if !handled && (isa(info, MethodMatchInfo) || isa(info, UnionSplitInfo)) - info = isa(info, MethodMatchInfo) ? - MethodMatchInfo[info] : info.matches + new_info = call.info + if isa(new_info, ConstCallInfo) + handle_const_call!( + ir, state1.id, new_stmt, new_info, + new_sig, istate, todo) + elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo) + new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches # See if we can inline this call to `iterate` analyze_single_call!( ir, todo, state1.id, new_stmt, - new_sig, info, istate) + new_sig, new_infos, istate) end if i != length(thisarginfo.each) valT = getfield_tfunc(call.rt, Const(1)) @@ -910,7 +903,9 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...) return (sig, state) end -function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Pair{Int, Any}}) +function handle_single_case!( + ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), + todo::Vector{Pair{Int, Any}}, isinvoke::Bool = false) if isa(case, ConstantCase) ir[SSAValue(idx)] = case.val elseif isa(case, MethodInstance) @@ -1086,13 +1081,13 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result): validate_sparams(mi.sparam_vals) || return nothing if argtypes_to_type(atypes) <: mi.def.sig state.mi_cache !== nothing && (item = resolve_todo(item, state)) - handle_single_case!(ir, stmt, idx, item, true, todo) + handle_single_case!(ir, stmt, idx, item, todo, true) return nothing end end result = analyze_method!(match, atypes, state) - handle_single_case!(ir, stmt, idx, result, true, todo) + handle_single_case!(ir, stmt, idx, result, todo, true) return nothing end @@ -1200,49 +1195,39 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return sig end -# TODO inline non-`isdispatchtuple`, union-split callsites +# TODO inline non-`isdispatchtuple`, union-split callsites? function analyze_single_call!( ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), - (; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState) + sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState) + (; atypes, atype) = sig cases = InliningCase[] local signature_union = Bottom local only_method = nothing # keep track of whether there is one matching method - local meth + local meth::MethodLookupResult local fully_covered = true for i in 1:length(infos) - info = infos[i] - meth = info.results + meth = infos[i].results if meth.ambig # Too many applicable methods # Or there is a (partial?) ambiguity - return + return nothing elseif length(meth) == 0 # No applicable methods; try next union split continue - elseif length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method + else + if length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1].method + elseif only_method !== meth[1].method + only_method = false + end + else only_method = false end - else - only_method = false end for match in meth - spec_types = match.spec_types - signature_union = Union{signature_union, spec_types} - if !isdispatchtuple(spec_types) - fully_covered = false - continue - end - item = analyze_method!(match, atypes, state) - if item === nothing - fully_covered = false - continue - elseif _any(case->case.sig === spec_types, cases) - continue - end - push!(cases, InliningCase(spec_types, item)) + signature_union = Union{signature_union, match.spec_types} + fully_covered &= handle_match!(match, atypes, state, cases) end end @@ -1253,9 +1238,8 @@ function analyze_single_call!( if length(infos) > 1 (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp, only_method, true) + match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) else - meth = meth::MethodLookupResult @assert length(meth) == 1 match = meth[1] end @@ -1268,46 +1252,41 @@ function analyze_single_call!( fully_covered = false end - # If we only have one case and that case is fully covered, we may either - # be able to do the inlining now (for constant cases), or push it directly - # onto the todo list - if fully_covered && length(cases) == 1 - handle_single_case!(ir, stmt, idx, cases[1].item, false, todo) - elseif length(cases) > 0 - push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) - end - return nothing + handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo) end -# try to create `InliningCase`s using constant-prop'ed results -# currently it works only when constant-prop' succeeded for all (union-split) signatures -# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later -# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out -function maybe_handle_const_call!( - ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature, - state::InliningState, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) - cases = InliningCase[] # TODO avoid this allocation for single cases ? +# similar to `analyze_single_call!`, but with constant results +function handle_const_call!( + ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, + sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) + (; atypes, atype) = sig + (; call, results) = cinfo + infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches + cases = InliningCase[] local fully_covered = true local signature_union = Bottom - for result in results - isa(result, InferenceResult) || return false - (; mi) = item = InliningTodo(result, atypes) - spec_types = mi.specTypes - signature_union = Union{signature_union, spec_types} - if !isdispatchtuple(spec_types) - fully_covered = false - continue - end - if !validate_sparams(mi.sparam_vals) - fully_covered = false + local j = 0 + for i in 1:length(infos) + meth = infos[i].results + if meth.ambig + # Too many applicable methods + # Or there is a (partial?) ambiguity + return nothing + elseif length(meth) == 0 + # No applicable methods; try next union split continue end - state.mi_cache !== nothing && (item = resolve_todo(item, state)) - if item === nothing - fully_covered = false - continue + for match in meth + j += 1 + result = results[j] + if result === nothing + signature_union = Union{signature_union, match.spec_types} + fully_covered &= handle_match!(match, atypes, state, cases) + else + signature_union = Union{signature_union, result.linfo.specTypes} + fully_covered &= handle_const_result!(result, atypes, state, cases) + end end - push!(cases, InliningCase(spec_types, item)) end # if the signature is fully covered and there is only one applicable method, @@ -1316,8 +1295,8 @@ function maybe_handle_const_call!( if length(cases) == 0 && length(results) == 1 (; mi) = item = InliningTodo(results[1]::InferenceResult, atypes) state.mi_cache !== nothing && (item = resolve_todo(item, state)) - validate_sparams(mi.sparam_vals) || return true - item === nothing && return true + validate_sparams(mi.sparam_vals) || return nothing + item === nothing && return nothing push!(cases, InliningCase(mi.specTypes, item)) fully_covered = true end @@ -1325,16 +1304,45 @@ function maybe_handle_const_call!( fully_covered = false end + handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo) +end + +function handle_match!( + match::MethodMatch, argtypes::Vector{Any}, state::InliningState, + cases::Vector{InliningCase}) + spec_types = match.spec_types + isdispatchtuple(spec_types) || return false + item = analyze_method!(match, argtypes, state) + item === nothing && return false + _any(case->case.sig === spec_types, cases) && return true + push!(cases, InliningCase(spec_types, item)) + return true +end + +function handle_const_result!( + result::InferenceResult, argtypes::Vector{Any}, state::InliningState, + cases::Vector{InliningCase}) + (; mi) = item = InliningTodo(result, argtypes) + spec_types = mi.specTypes + isdispatchtuple(spec_types) || return false + validate_sparams(mi.sparam_vals) || return false + state.mi_cache !== nothing && (item = resolve_todo(item, state)) + item === nothing && return false + push!(cases, InliningCase(spec_types, item)) + return true +end + +function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature, + cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}}) # If we only have one case and that case is fully covered, we may either # be able to do the inlining now (for constant cases), or push it directly # onto the todo list if fully_covered && length(cases) == 1 - handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo) + handle_single_case!(ir, stmt, idx, cases[1].item, todo) elseif length(cases) > 0 - isinvoke && rewrite_invoke_exprargs!(stmt) - push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) + push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases)) end - return true + return nothing end function handle_const_opaque_closure_call!( @@ -1346,7 +1354,7 @@ function handle_const_opaque_closure_call!( isdispatchtuple(item.mi.specTypes) || return validate_sparams(item.mi.sparam_vals) || return state.mi_cache !== nothing && (item = resolve_todo(item, state)) - handle_single_case!(ir, stmt, idx, item, false, todo) + handle_single_case!(ir, stmt, idx, item, todo) return nothing end @@ -1371,9 +1379,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE info = info.info end - - # Inference determined this couldn't be analyzed. Don't question it. if info === false + # Inference determined this couldn't be analyzed. Don't question it. continue end @@ -1386,16 +1393,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) sig, state, todo) continue else - maybe_handle_const_call!( + handle_const_call!( ir, idx, stmt, info, sig, - state, sig.f === Core.invoke, todo) && continue + state, todo) end - info = info.call # cascade to the non-constant handling end if isa(info, OpaqueClosureCallInfo) item = analyze_method!(info.match, sig.atypes, state) - handle_single_case!(ir, stmt, idx, item, false, todo) + handle_single_case!(ir, stmt, idx, item, todo) continue end diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 0c54e9359fa1a2..bc6adcaefa7a86 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -40,6 +40,27 @@ struct UnionSplitInfo matches::Vector{MethodMatchInfo} end +nmatches(info::MethodMatchInfo) = length(info.results) +function nmatches(info::UnionSplitInfo) + n = 0 + for mminfo in info.matches + n += nmatches(mminfo) + end + return n +end + +""" + info::ConstCallInfo + +The precision of this call was improved using constant information. +In addition to the original call information `info.call`, this info also keeps +the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`. +""" +struct ConstCallInfo + call::Union{MethodMatchInfo,UnionSplitInfo} + results::Vector{Union{Nothing,InferenceResult}} +end + """ struct CallMeta @@ -88,18 +109,6 @@ struct UnionSplitApplyCallInfo infos::Vector{ApplyCallInfo} end -""" - struct ConstCallInfo - -Precision for this call was improved using constant information. This info -keeps a reference to the result that was used (or created for these) -constant information. -""" -struct ConstCallInfo - call::Any - results::Vector{Union{Nothing,InferenceResult}} -end - """ struct InvokeCallInfo diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index f157f78a433e7b..b3901a02df35fe 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -438,17 +438,19 @@ end import Base: @constprop # test union-split callsite with successful and unsuccessful constant-prop' results -@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined -@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved -let src = code_typed1((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs - f42840(xs, 2) - end - @test count(src.code) do @nospecialize x - iscall((src, getfield), x) # `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)` - end == 1 - @test count(src.code) do @nospecialize x - isinvoke(:f42840, x) - end == 1 +# (also for https://github.com/JuliaLang/julia/issues/43287) +@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result + cond ? xs[a] : @noinline(length(xs)) +@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved + xs[a] +let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs + f42840(true, xs, 2) + end |> only |> first + # `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)` + # `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)` + @test count(iscall((src, getfield)), src.code) == 1 + @test count(isinvoke(:length), src.code) == 0 + @test count(isinvoke(:f42840), src.code) == 1 end # a bit weird, but should handle this kind of case as well @constprop :aggressive @noinline g42840(xs, a::Int) = xs[a] # should be successful, but only statically resolved