Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow inlining methods with unmatched type parameters #45062

Merged
merged 1 commit into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 117 additions & 26 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,28 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
boundscheck = :off
end
end
if !validate_sparams(sparam_vals)
if def.isva
nonva_args = argexprs[1:end-1]
va_arg = argexprs[end]
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args])
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
aviatesk marked this conversation as resolved.
Show resolved Hide resolved
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline)))
Comment on lines +381 to +388
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems like it would add significant cost (and allocations and new DataType's) that were not present before inlining? are we sure this is profitable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prior to the fix_va_argexprs! call that you have above, the argexprs list was required to be a known length and did not have or need a va_arg

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in #46700.

else
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
new_new_offset = length(compact.new_new_nodes)
late_fixup_offset = length(compact.late_fixup)
if spec.linear_inline_eligible
#compact[idx] = nothing
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
Expand All @@ -389,7 +406,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
return_value = SSAValue(idx′)
Expand All @@ -402,7 +419,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
else
bb_offset, post_bb_id = popfirst!(todo_bbs)
Expand All @@ -416,7 +433,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand All @@ -436,7 +453,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
compact.active_result_bb = inline_compact.active_result_bb
if length(pn.edges) == 1
Expand All @@ -460,7 +477,8 @@ function fix_va_argexprs!(compact::IncrementalCompact,
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx)
push!(newargexprs, insert_node_here!(compact, tuple_inst))
return newargexprs
end

Expand Down Expand Up @@ -875,8 +893,26 @@ function validate_sparams(sparams::SimpleVector)
return true
end

function may_have_fcalls(m::Method)
may_have_fcall = true
if isdefined(m, :source)
src = m.source
isa(src, Vector{UInt8}) && (src = uncompressed_ir(m))
if isa(src, CodeInfo)
may_have_fcall = src.has_fcall
end
end
return may_have_fcall
end

function can_inline_typevars(m::MethodMatch, argtypes::Vector{Any})
may_have_fcalls(m.method) && return false
any(@nospecialize(x) -> x isa UnionAll, argtypes[2:end]) && return false
return true
end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
flag::UInt8, state::InliningState)
flag::UInt8, state::InliningState, allow_typevars::Bool = false)
method = match.method
spec_types = match.spec_types

Expand All @@ -898,8 +934,9 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
end
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
if !validate_sparams(match.sparams)
(allow_typevars && can_inline_typevars(match, argtypes)) || return nothing
end

et = state.et

Expand Down Expand Up @@ -1231,6 +1268,9 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
flag::UInt8, sig::Signature, state::InliningState)
argtypes = sig.argtypes
cases = InliningCase[]
local only_method = nothing
local meth::MethodLookupResult
local revisit_idx = nothing
local any_fully_covered = false
local handled_all_cases = true
for i in 1:length(infos)
Expand All @@ -1243,14 +1283,58 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end
for match in meth
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
for (j, match) in enumerate(meth)
any_fully_covered |= match.fully_covers
if !validate_sparams(match.sparams)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
end
end
end

if !handled_all_cases
if handled_all_cases && revisit_idx !== nothing
# we handled everything except one match with unmatched sparams,
# so try to handle it by bypassing validate_sparams
(i, j) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even in the prescence of unmatched sparams
# -- But don't try it if we already tried to handle the match in the revisit_idx
# case, because that'll (necessarily) be the same method.
if length(infos) > 1
atype = argtypes_to_type(argtypes)
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
@assert length(meth) == 1
match = meth[1]
end
handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) || return nothing
any_fully_covered = handled_all_cases = match.fully_covers
elseif !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end
Expand Down Expand Up @@ -1286,10 +1370,10 @@ function compute_inlining_cases(info::ConstCallInfo,
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
end
end
end
Expand Down Expand Up @@ -1324,22 +1408,22 @@ end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
cases::Vector{InliningCase}, allow_abstract::Bool, allow_typevars::Bool)
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# we may see duplicated dispatch signatures here when a signature gets widened
# We may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, nothing, flag, state)
# processing this dispatch candidate (unless unmatched type parameters are present)
!allow_typevars && _any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, nothing, flag, state, allow_typevars)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
end

function handle_const_prop_result!(
result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
cases::Vector{InliningCase}, allow_abstract::Bool)
(; mi) = item = InliningTodo(result.result, argtypes)
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
Expand Down Expand Up @@ -1624,30 +1708,37 @@ function late_inline_special_case!(
end

function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector,
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
compact.result[idx][:line] += linetable_offset
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
end

function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
boundscheck::Symbol, compact::IncrementalCompact, idx::Int)
if isa(val, Argument)
return arg_replacements[val.n]
end
if isa(val, Expr)
e = val::Expr
head = e.head
if head === :static_parameter
return quoted(spvals[e.args[1]::Int])
elseif head === :cfunction
if isa(spvals, SimpleVector)
return quoted(spvals[e.args[1]::Int])
else
ret = insert_node!(compact, SSAValue(idx),
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
return ret
end
elseif head === :cfunction && isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
elseif head === :foreigncall
elseif head === :foreigncall && isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
Comment on lines +1741 to 1742
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be an assert? if it managed to get this far, but didn't run this fixup code here, it will generate corrupt code later and possibly segfault later at runtime

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming you're talking about the isa? If so, I don't think so - there can be :foreigncalls that got inlined from a callee that don't need the fixup here (though it's not harmful either).

for i = 1:length(e.args)
if i == 2
Expand All @@ -1671,7 +1762,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
end
return urs[]
end
Loading