Skip to content

Commit

Permalink
Fixups and use concrete spvals in ssa_substitute
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol committed May 9, 2022
1 parent e0a1113 commit 0d6db0e
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 45 deletions.
84 changes: 48 additions & 36 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, item.mi.sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
return_value = SSAValue(idx′)
Expand All @@ -414,7 +414,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, item.mi.sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -892,7 +892,7 @@ function validate_sparams(sparams::SimpleVector)
end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
flag::UInt8, state::InliningState)
flag::UInt8, state::InliningState, check_sparams::Bool = false)
method = match.method
spec_types = match.spec_types

Expand All @@ -914,7 +914,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
end
end

#validate_sparams(match.sparams) || return nothing
check_sparams && (validate_sparams(match.sparams) || return nothing)

et = state.et

Expand Down Expand Up @@ -1121,7 +1121,7 @@ function inline_invoke!(
argtypes = invoke_rewrite(sig.argtypes)
if isa(result, ConstPropResult)
(; mi) = item = InliningTodo(result.result, argtypes)
# validate_sparams(mi.sparam_vals) || return nothing
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
Expand Down Expand Up @@ -1250,7 +1250,6 @@ function analyze_single_call!(
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local revisit_idx = nothing
local only_method = nothing # tri-valued: nothing if unknown, false if proven untrue, otherwise the method itself
local meth::MethodLookupResult
for i in 1:length(infos)
Expand All @@ -1263,9 +1262,21 @@ function analyze_single_call!(
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end

check_sparams = isa(only_method, Bool)
for match in meth
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true, check_sparams)
any_fully_covered |= match.fully_covers
end
end
Expand Down Expand Up @@ -1316,7 +1327,7 @@ function handle_const_call!(
j += 1
result = results[j]
any_fully_covered |= match.fully_covers
check_sparams = isa(only_method, Bool)
check_sparams = isa(only_method, Bool) # validate sparams if we know this meth has >1 match
if isa(result, ConcreteResult)
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
Expand All @@ -1340,14 +1351,18 @@ end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
cases::Vector{InliningCase}, allow_abstract::Bool = false, check_sparams::Bool = false)
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# we may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, flag, state)
if check_sparams
# we may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, flag, state, true)
else
item = analyze_method!(match, argtypes, flag, state)
end
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
Expand Down Expand Up @@ -1397,6 +1412,7 @@ function handle_const_opaque_closure_call!(
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
item = InliningTodo(result.result, sig.argtypes)
isdispatchtuple(item.mi.specTypes) || return
validate_sparams(item.mi.sparam_vals) || return
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params)
return nothing
Expand Down Expand Up @@ -1574,16 +1590,16 @@ function late_inline_special_case!(
end

function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, concrete_spvals::SimpleVector,
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
compact.result[idx][:line] += linetable_offset
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, concrete_spvals, boundscheck, compact, idx)
end

function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, boundscheck::Symbol,
compact::IncrementalCompact, idx::Int)
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, concrete_spvals::SimpleVector,
boundscheck::Symbol, compact::IncrementalCompact, idx::Int)
if isa(val, Argument)
return arg_replacements[val.n]
end
Expand All @@ -1599,24 +1615,20 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
return ret
end
elseif head === :cfunction
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
end
@assert !isa(spsig, UnionAll) || !isempty(concrete_spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, concrete_spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, concrete_spvals)
for argt in e.args[4]::SimpleVector ]...)
elseif head === :foreigncall
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[3]::SimpleVector ]...)
end
@assert !isa(spsig, UnionAll) || !isempty(concrete_spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, concrete_spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, concrete_spvals)
for argt in e.args[3]::SimpleVector ]...)
end
end
elseif head === :boundscheck
Expand All @@ -1631,7 +1643,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
end
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, concrete_spvals, boundscheck, compact, idx)
end
return urs[]
end
14 changes: 9 additions & 5 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1492,15 +1492,15 @@ function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{A
end
values[i] = val
end
return FixedNode(values, needs_fixup)
return (values, needs_fixup)
end

function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
if isa(stmt, PhiNode)
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
(node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiNode(stmt.edges, node), needs_fixup)
elseif isa(stmt, PhiCNode)
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
(node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiCNode(node), needs_fixup)
elseif isa(stmt, NewSSAValue)
if reify_new_nodes
Expand All @@ -1509,7 +1509,11 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_
return FixedNode(stmt, true)
end
elseif isa(stmt, OldSSAValue)
return FixedNode(compact.ssa_rename[stmt.id], false)
val = compact.ssa_rename[stmt.id]
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
compact.used_ssas[val.id] += 1
end
return FixedNode(val, false)
else
urs = userefs(stmt)
needs_fixup = false
Expand All @@ -1536,7 +1540,7 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_
end
end

function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing)
function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing} = nothing)
if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish!
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ function lift_svec_ref!(compact, idx, stmt)
val = stmt.args[4]
valT = argextype(val, compact)
(isa(valT, Const) && isa(valT.val, Int)) || return
valI = valT.val
valI = valT.val::Int
(1 <= valI) || return

if isa(vec, SimpleVector)
Expand Down
7 changes: 4 additions & 3 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1613,11 +1613,12 @@ JL_CALLABLE(jl_f__svec_ref)
JL_TYPECHK(_svec_ref, bool, b);
JL_TYPECHK(_svec_ref, simplevector, (jl_value_t*)s);
JL_TYPECHK(_svec_ref, long, i);
size_t len = jl_svec_len(s);
ssize_t idx = jl_unbox_long(i);
if (idx < 1 || idx > jl_svec_len(s)) {
jl_bounds_error_int(s, i);
if (idx < 1 || idx > len) {
jl_bounds_error_int((jl_value_t*)s, idx);
}
return jl_svec_ref(s, jl_unbox_long(i)-1);
return jl_svec_ref(s, idx-1);
}

static int equiv_field_types(jl_value_t *old, jl_value_t *ft)
Expand Down

0 comments on commit 0d6db0e

Please sign in to comment.