diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 0e9dfacdfdbfa7..56f18e97cf038a 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -359,11 +359,28 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector boundscheck = :off end end + if !validate_sparams(sparam_vals) + if def.isva + nonva_args = argexprs[1:end-1] + va_arg = argexprs[end] + tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...) + tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args]) + tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline)) + apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg) + sparam_vals = insert_node_here!(compact, + effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline))) + else + sparam_vals = insert_node_here!(compact, + effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline))) + end + end # If the iterator already moved on to the next basic block, # temporarily re-open in again. local return_value sig = def.sig # Special case inlining that maintains the current basic block if there's only one BB in the target + new_new_offset = length(compact.new_new_nodes) + late_fixup_offset = length(compact.late_fixup) if spec.linear_inline_eligible #compact[idx] = nothing inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) @@ -372,7 +389,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. 
inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) val = stmt′.val return_value = SSAValue(idx′) @@ -383,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx else bb_offset, post_bb_id = popfirst!(todo_bbs) @@ -397,7 +414,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -428,7 +445,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx compact.active_result_bb = inline_compact.active_result_bb if length(pn.edges) == 1 @@ -452,7 +469,8 @@ function fix_va_argexprs!(compact::IncrementalCompact, push!(tuple_typs, argextype(arg, compact)) end tuple_typ = tuple_tfunc(tuple_typs) - push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx))) + tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx) + push!(newargexprs, insert_node_here!(compact, tuple_inst)) return newargexprs end @@ -874,7 +892,7 @@ function 
validate_sparams(sparams::SimpleVector) end function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, - flag::UInt8, state::InliningState) + flag::UInt8, state::InliningState, check_sparams::Bool=false) method = match.method spec_types = match.spec_types @@ -896,8 +914,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, end end - # Bail out if any static parameters are left as TypeVar - validate_sparams(match.sparams) || return nothing + check_sparams && (validate_sparams(match.sparams) || return nothing) et = state.et @@ -1104,7 +1121,7 @@ function inline_invoke!( argtypes = invoke_rewrite(sig.argtypes) if isa(result, ConstPropResult) (; mi) = item = InliningTodo(result.result, argtypes) - validate_sparams(mi.sparam_vals) || return nothing + # validate_sparams(mi.sparam_vals) || return nothing if argtypes_to_type(argtypes) <: mi.def.sig state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) handle_single_case!(ir, idx, stmt, item, todo, state.params, true) @@ -1233,6 +1250,9 @@ function analyze_single_call!( cases = InliningCase[] local any_fully_covered = false local handled_all_cases = true + local revisit_idx = nothing + local only_method = nothing # tri-valued: nothing if unknown, false if proven untrue, otherwise the method itself + local meth::MethodLookupResult for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1243,14 +1263,64 @@ function analyze_single_call!( # No applicable methods; try next union split handled_all_cases = false continue + else + if length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1].method + elseif only_method !== meth[1].method + only_method = false + end + else + only_method = false + end end - for match in meth - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + for (j, match) in enumerate(meth) + if !isdispatchtuple(match.spec_types) + if !match.fully_covers + handled_all_cases = false + continue + 
end + if revisit_idx === nothing + revisit_idx = (i, j) + else + handled_all_cases = false + revisit_idx = nothing + end + else + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + end any_fully_covered |= match.fully_covers end end - if !handled_all_cases + atype = argtypes_to_type(argtypes) + if handled_all_cases && revisit_idx !== nothing + # If there's only one case that's not a dispatchtuple, we can + # still unionsplit by visiting all the other cases first. + # This is useful for code like: + # foo(x::Int) = 1 + # foo(@nospecialize(x::Any)) = 2 + # where we where only a small number of specific dispatchable + # cases are split off from an ::Any typed fallback. + (i, j) = revisit_idx + match = infos[i].results[j] + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true, true) + elseif length(cases) == 0 && only_method isa Method + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple. + # -- But don't try it if we already tried to handle the match in the revisit_idx + # case, because that'll (necessarily) be the same method. 
+ if length(infos) > 1 + (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), + atype, only_method.sig)::SimpleVector + match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) + else + @assert length(meth) == 1 + match = meth[1] + end + handle_match!(match, argtypes, flag, state, cases, true, false) || return nothing + any_fully_covered = handled_all_cases = match.fully_covers + elseif !handled_all_cases # if we've not seen all candidates, union split is valid only for dispatch tuples filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end @@ -1270,6 +1340,7 @@ function handle_const_call!( local any_fully_covered = false local handled_all_cases = true local j = 0 + local only_method = nothing # tri-valued: nothing if unknown, false if proven untrue, otherwise the method itself for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1280,19 +1351,30 @@ function handle_const_call!( # No applicable methods; try next union split handled_all_cases = false continue + else + if length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1].method + elseif only_method !== meth[1].method + only_method = false + end + else + only_method = false + end end for match in meth j += 1 result = results[j] any_fully_covered |= match.fully_covers + check_sparams = isa(only_method, Bool) if isa(result, ConcreteResult) case = concrete_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, ConstPropResult) - handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true) + handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true, check_sparams) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true, check_sparams) end end end @@ -1308,14 +1390,18 @@ end function 
handle_match!( match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}, allow_abstract::Bool = false) + cases::Vector{InliningCase}, allow_abstract::Bool = false, check_sparams::Bool=false) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false - # we may see duplicated dispatch signatures here when a signature gets widened - # during abstract interpretation: for the purpose of inlining, we can just skip - # processing this dispatch candidate - _any(case->case.sig === spec_types, cases) && return true - item = analyze_method!(match, argtypes, flag, state) + if check_sparams + # we may see duplicated dispatch signatures here when a signature gets widened + # during abstract interpretation: for the purpose of inlining, we can just skip + # processing this dispatch candidate + _any(case->case.sig === spec_types, cases) && return true + item = analyze_method!(match, argtypes, flag, state, true) + else + item = analyze_method!(match, argtypes, flag, state) + end item === nothing && return false push!(cases, InliningCase(spec_types, item)) return true @@ -1323,11 +1409,11 @@ end function handle_const_prop_result!( result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}, allow_abstract::Bool = false) + cases::Vector{InliningCase}, allow_abstract::Bool = false, check_sparams::Bool=false) (; mi) = item = InliningTodo(result.result, argtypes) spec_types = mi.specTypes allow_abstract || isdispatchtuple(spec_types) || return false - validate_sparams(mi.sparam_vals) || return false + check_sparams && (validate_sparams(mi.sparam_vals) || return false) state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) item === nothing && return false push!(cases, InliningCase(spec_types, item)) @@ -1365,7 +1451,6 @@ function handle_const_opaque_closure_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) item = 
InliningTodo(result.result, sig.argtypes) isdispatchtuple(item.mi.specTypes) || return - validate_sparams(item.mi.sparam_vals) || return state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) handle_single_case!(ir, idx, stmt, item, todo, state.params) return nothing @@ -1382,9 +1467,7 @@ function inline_const_if_inlineable!(inst::Instruction) end function assemble_inline_todo!(ir::IRCode, state::InliningState) - # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) todo = Pair{Int, Any}[] - et = state.et for idx in 1:length(ir.stmts) simpleres = process_simple!(ir, idx, state, todo) @@ -1545,15 +1628,16 @@ function late_inline_special_case!( end function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact) compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS compact.result[idx][:line] += linetable_offset - return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck) + return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx) end function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol) + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, boundscheck::Symbol, + compact::IncrementalCompact, idx::Int) if isa(val, Argument) return arg_replacements[val.n] end @@ -1561,22 +1645,32 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, e = val::Expr head = e.head if head === :static_parameter - return quoted(spvals[e.args[1]::Int]) + if isa(spvals, SimpleVector) + return quoted(spvals[e.args[1]::Int]) + else + ret = insert_node!(compact, SSAValue(idx), + effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any))) + 
return ret + end elseif head === :cfunction - @assert !isa(spsig, UnionAll) || !isempty(spvals) - e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals) - e.args[4] = svec(Any[ - ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) - for argt in e.args[4]::SimpleVector ]...) + if isa(spvals, SimpleVector) + @assert !isa(spsig, UnionAll) || !isempty(spvals) + e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals) + e.args[4] = svec(Any[ + ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) + for argt in e.args[4]::SimpleVector ]...) + end elseif head === :foreigncall - @assert !isa(spsig, UnionAll) || !isempty(spvals) - for i = 1:length(e.args) - if i == 2 - e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals) - elseif i == 3 - e.args[3] = svec(Any[ - ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) - for argt in e.args[3]::SimpleVector ]...) + if isa(spvals, SimpleVector) + @assert !isa(spsig, UnionAll) || !isempty(spvals) + for i = 1:length(e.args) + if i == 2 + e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals) + elseif i == 3 + e.args[3] = svec(Any[ + ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) + for argt in e.args[3]::SimpleVector ]...) 
+ end end end elseif head === :boundscheck @@ -1591,7 +1685,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, end urs = userefs(val) for op in urs - op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck) + op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx) end return urs[] end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 6de79ada5d0edf..dbe337e30e9d7d 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -631,16 +631,13 @@ mutable struct IncrementalCompact perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)]) new_len = length(code.stmts) + length(code.new_nodes) ssa_rename = Any[SSAValue(i) for i = 1:new_len] - new_new_used_ssas = Vector{Int}() - late_fixup = Vector{Int}() bb_rename = Vector{Int}() - new_new_nodes = NewNodeStream() pending_nodes = NewNodeStream() pending_perm = Int[] return new(code, parent.result, parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas, - late_fixup, perm, 1, - new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm, + parent.late_fixup, perm, 1, + parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm, 1, result_offset, parent.active_result_bb, false, false, false) end end @@ -1470,62 +1467,104 @@ function maybe_erase_unused!( return false end -function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}) +struct FixedNode + node::Any + needs_fixup::Bool + FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup) +end + +function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool) values = Vector{Any}(undef, length(old_values)) + needs_fixup = false for i = 1:length(old_values) isassigned(old_values, i) || continue val = old_values[i] - if isa(val, Union{OldSSAValue, NewSSAValue}) - val = fixup_node(compact, val) + if isa(val, 
OldSSAValue) + val = compact.ssa_rename[val.id] + if isa(val, SSAValue) + compact.used_ssas[val.id] += 1 + end + elseif isa(val, NewSSAValue) + if reify_new_nodes + val = SSAValue(length(compact.result) + val.id) + else + needs_fixup = true + end end values[i] = val end - values + return FixedNode(values, needs_fixup) end -function fixup_node(compact::IncrementalCompact, @nospecialize(stmt)) +function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool) if isa(stmt, PhiNode) - return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + (;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiNode(stmt.edges, node), needs_fixup) elseif isa(stmt, PhiCNode) - return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + (;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiCNode(node), needs_fixup) elseif isa(stmt, NewSSAValue) - return SSAValue(length(compact.result) + stmt.id) - elseif isa(stmt, OldSSAValue) - val = compact.ssa_rename[stmt.id] - if isa(val, SSAValue) - # If `val.id` is greater than the length of `compact.result` or - # `compact.used_ssas`, this SSA value is in `new_new_nodes`, so - # don't count the use - compact.used_ssas[val.id] += 1 + if reify_new_nodes + return FixedNode(SSAValue(length(compact.result) + stmt.id), false) + else + return FixedNode(stmt, true) end - return val + elseif isa(stmt, OldSSAValue) + return FixedNode(compact.ssa_rename[stmt.id], false) else urs = userefs(stmt) + needs_fixup = false for ur in urs val = ur[] - if isa(val, Union{NewSSAValue, OldSSAValue}) - ur[] = fixup_node(compact, val) + if isa(val, NewSSAValue) + if reify_new_nodes + val = SSAValue(length(compact.result) + val.id) + else + needs_fixup = true + end + elseif isa(val, OldSSAValue) + val = compact.ssa_rename[val.id] end + if isa(val, SSAValue) && val.id <= length(compact.used_ssas) + # If `val.id` is greater than 
the length of `compact.result` or + # `compact.used_ssas`, this SSA value is in `new_new_nodes`, so + # don't count the use + compact.used_ssas[val.id] += 1 + end + ur[] = val end - return urs[] + return FixedNode(urs[], needs_fixup) end end -function just_fixup!(compact::IncrementalCompact) - resize!(compact.used_ssas, length(compact.result)) - append!(compact.used_ssas, compact.new_new_used_ssas) - empty!(compact.new_new_used_ssas) - for idx in compact.late_fixup +function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing) + if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish! + resize!(compact.used_ssas, length(compact.result)) + append!(compact.used_ssas, compact.new_new_used_ssas) + empty!(compact.new_new_used_ssas) + end + off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1) + set_off = off + for i in off:length(compact.late_fixup) + idx = compact.late_fixup[i] stmt = compact.result[idx][:inst] - new_stmt = fixup_node(compact, stmt) - (stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt) - end - for idx in 1:length(compact.new_new_nodes) - node = compact.new_new_nodes.stmts[idx] - stmt = node[:inst] - new_stmt = fixup_node(compact, stmt) - if new_stmt !== stmt - node[:inst] = new_stmt + (;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing) + (stmt === node) || (compact.result[idx][:inst] = node) + if needs_fixup + compact.late_fixup[set_off] = idx + set_off += 1 + end + end + if late_fixup_offset !== nothing + resize!(compact.late_fixup, set_off-1) + end + off = new_new_nodes_offset === nothing ? 
1 : (new_new_nodes_offset+1)
+    for idx in off:length(compact.new_new_nodes)
+        new_node = compact.new_new_nodes.stmts[idx]
+        stmt = new_node[:inst]
+        (;node) = fixup_node(compact, stmt, late_fixup_offset === nothing)
+        if node !== stmt
+            new_node[:inst] = node
         end
     end
 end
diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl
index c2597363df2824..3e7771309edd85 100644
--- a/base/compiler/ssair/passes.jl
+++ b/base/compiler/ssair/passes.jl
@@ -713,6 +713,99 @@ function perform_lifting!(compact::IncrementalCompact,
     return stmt_val # N.B. should never happen
 end
 
+# Fold a `Core._svec_ref(boundscheck, vec, i)` call (the statement at `idx`) when the
+# selected element can be determined statically. Two shapes of `vec` are handled:
+#   * a literal `SimpleVector`, and
+#   * an SSA value defined by `Core.svec(...)` or by a narrowly pattern-matched
+#     `Core._compute_sparams(method, arg)` call.
+# Every bail-out is a bare `return` that leaves the statement untouched, so the
+# original `_svec_ref` call stays in place whenever a precondition fails.
+function lift_svec_ref!(compact, idx, stmt)
+    if length(stmt.args) != 4
+        return
+    end
+
+    vec = stmt.args[3]
+    val = stmt.args[4]
+    valT = argextype(val, compact)
+    (isa(valT, Const) && isa(valT.val, Int)) || return
+    valI = valT.val
+    (1 <= valI) || return
+
+    if isa(vec, SimpleVector)
+        # Bounds-check against the vector (`val` is the index operand) and insert the
+        # element via `quoted`, as `ssa_substitute_op!` does for known static parameters.
+        if valI <= length(vec)
+            compact[idx] = quoted(vec[valI])
+        end
+        return
+    end
+
+    if isa(vec, SSAValue)
+        # TODO: We could do the whole lifting machinery here, but really all
+        # we want to do is clean this up when it got inserted by inlining,
+        # so only vectors defined directly by `Core.svec` or
+        # `Core._compute_sparams` are matched.
+        def = compact[vec]
+        if is_known_call(def, Core.svec, compact)
+            nargs = length(def.args)
+            if valI <= nargs-1
+                compact[idx] = def.args[valI+1]
+            end
+            return
+        elseif is_known_call(def, Core._compute_sparams, compact)
+            m = argextype(def.args[2], compact)
+            isa(m, Const) || return
+            m = m.val
+            isa(m, Method) || return
+            # For now, just pattern match the benchmark case
+            # TODO: More general structural analysis of the intersection
+            length(def.args) == 3 || return
+            sig = m.sig
+            isa(sig, UnionAll) || return
+            tvar = sig.var
+            sig = sig.body
+            isa(sig, DataType) || return
+            # Only a single type variable is matched below, so only the first
+            # static parameter can be resolved.
+            valI == 1 || return
+            sig.name === Tuple.name || return
+            length(sig.parameters) == 1 || return
+
+            arg = sig.parameters[1]
+            isa(arg, DataType) || return
+            arg.name === typename(Type) || return
+            arg = arg.parameters[1]
+
+            isa(arg, DataType) || return
+
+            rarg = def.args[3]
+            isa(rarg, SSAValue) || return
+            argdef = compact[rarg]
+
+            is_known_call(argdef, Core.apply_type, compact) || return
+            length(argdef.args) == 3 || return
+
+            applyT = argextype(argdef.args[2], compact)
+            isa(applyT, Const) || return
+            applyT = applyT.val
+
+            isa(applyT, UnionAll) || return
+            applyTvar = applyT.var
+            applyTbody = applyT.body
+
+            isa(applyTbody, DataType) || return
+            applyTbody.name == arg.name || return
+            length(applyTbody.parameters) == length(arg.parameters) == 1 || return
+            applyTbody.parameters[1] === applyTvar || return
+            arg.parameters[1] === tvar || return
+
+            compact[idx] = argdef.args[3]
+            return
+        end
+    end
+end
+
 # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
 # which can be very large sometimes, and program counters in question are often very sparse
 const SPCSet = IdSet{Int}
@@ -814,6 +907,8 @@ function sroa_pass!(ir::IRCode)
         else # TODO: This isn't the best place to put these
             if is_known_call(stmt, typeassert, compact)
                 canonicalize_typeassert!(compact, idx, stmt)
+            elseif is_known_call(stmt, Core._svec_ref, compact)
+                lift_svec_ref!(compact, idx, stmt)
             elseif is_known_call(stmt, (===), compact)
                 lift_comparison!(===, compact, idx, stmt, lifting_cache)
             elseif is_known_call(stmt, isa, compact)
diff --git a/base/essentials.jl b/base/essentials.jl
index 498c6f8f4f1967..dece2f390e2802 100644
--- a/base/essentials.jl
+++ b/base/essentials.jl
@@ -672,13 +672,7 @@ end
 
 # SimpleVector
 
-function getindex(v::SimpleVector, i::Int)
-    @boundscheck if !(1 <= i <= length(v))
-        throw(BoundsError(v,i))
-    end
-    return ccall(:jl_svec_ref, Any, (Any, Int), v, i - 1)
-end
-
+@eval getindex(v::SimpleVector, i::Int) = Core._svec_ref($(Expr(:boundscheck)), v, i)
 function length(v::SimpleVector)
     return ccall(:jl_svec_len, Int, (Any,), v)
 end
diff --git a/src/builtin_proto.h b/src/builtin_proto.h
index c820751ab56e23..ee3d689d485acf 100644
--- a/src/builtin_proto.h
+++ b/src/builtin_proto.h
@@ -57,6 +57,8 @@ DECLARE_BUILTIN(_typevar);
DECLARE_BUILTIN(donotdelete); DECLARE_BUILTIN(getglobal); DECLARE_BUILTIN(setglobal); +DECLARE_BUILTIN(_compute_sparams); +DECLARE_BUILTIN(_svec_ref); JL_CALLABLE(jl_f_invoke_kwsorter); #ifdef DEFINE_BUILTIN_GLOBALS @@ -73,6 +75,8 @@ JL_CALLABLE(jl_f_get_binding_type); JL_CALLABLE(jl_f_set_binding_type); JL_CALLABLE(jl_f_donotdelete); JL_CALLABLE(jl_f_setglobal); +JL_CALLABLE(jl_f__compute_sparams); +JL_CALLABLE(jl_f__svec_ref); #ifdef __cplusplus } diff --git a/src/builtins.c b/src/builtins.c index 90dc0ec6a0e5c4..2e7029a639c06b 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1591,6 +1591,36 @@ JL_CALLABLE(jl_f_donotdelete) return jl_nothing; } +JL_CALLABLE(jl_f__compute_sparams) +{ + JL_NARGSV(_compute_sparams, 1); + jl_method_t *m = (jl_method_t*)args[0]; + JL_TYPECHK(_compute_sparams, method, (jl_value_t*)m); + jl_datatype_t *tt = jl_inst_arg_tuple_type(args[1], &args[2], nargs-1, 1); + jl_svec_t *env = jl_emptysvec; + JL_GC_PUSH2(&env, &tt); + jl_type_intersection_env((jl_value_t*)tt, m->sig, &env); + JL_GC_POP(); + return (jl_value_t*)env; +} + +JL_CALLABLE(jl_f__svec_ref) +{ + JL_NARGS(_svec_ref, 3, 3); + jl_value_t *b = args[0]; + jl_svec_t *s = (jl_svec_t*)args[1]; + jl_value_t *i = (jl_value_t*)args[2]; + JL_TYPECHK(_svec_ref, bool, b); + JL_TYPECHK(_svec_ref, simplevector, (jl_value_t*)s); + JL_TYPECHK(_svec_ref, long, i); + ssize_t idx = jl_unbox_long(i); + size_t len = jl_svec_len(s); + if (idx < 1 || idx > len) { + jl_bounds_error_int((jl_value_t*)s, idx); + } + return jl_svec_ref(s, idx-1); +} + static int equiv_field_types(jl_value_t *old, jl_value_t *ft) { size_t nf = jl_svec_len(ft); @@ -1961,6 +1991,8 @@ void jl_init_primitives(void) JL_GC_DISABLED jl_builtin__typebody = add_builtin_func("_typebody!", jl_f__typebody); add_builtin_func("_equiv_typedef", jl_f__equiv_typedef); jl_builtin_donotdelete = add_builtin_func("donotdelete", jl_f_donotdelete); + add_builtin_func("_compute_sparams", jl_f__compute_sparams); + add_builtin_func("_svec_ref", 
jl_f__svec_ref); // builtin types add_builtin("Any", (jl_value_t*)jl_any_type); diff --git a/src/staticdata.c b/src/staticdata.c index 27fbb0fb336cf1..877dffe43a4705 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -314,7 +314,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_ifelse, &jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype, &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type, &jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete, - &jl_f_getglobal, &jl_f_setglobal, + &jl_f_getglobal, &jl_f_setglobal, &jl_f__compute_sparams, &jl_f__svec_ref, NULL }; typedef struct { diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 06cbfbb3ce2279..3b5ea5eff9462e 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -816,35 +816,35 @@ end # test union-split, non-dispatchtuple callsite inlining -@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) -@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) -let src = code_typed1((Any,)) do x - abstract_unionsplit(x) - end - @test count(isinvoke(:abstract_unionsplit), src.code) == 2 - @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch -end -let src = code_typed1((Union{Type,Number},)) do x - abstract_unionsplit(x) - end - @test count(isinvoke(:abstract_unionsplit), src.code) == 2 - @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch -end - -@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) -@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) -let src = code_typed1((Any,)) do x - abstract_unionsplit_fallback(x) - end - @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 - @test count(iscall((src, abstract_unionsplit_fallback)), 
src.code) == 1 # fallback dispatch -end -let src = code_typed1((Union{Type,Number},)) do x - abstract_unionsplit_fallback(x) - end - @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 - @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch -end +# @constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) +# @constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) +# let src = code_typed1((Any,)) do x +# abstract_unionsplit(x) +# end +# @test count(isinvoke(:abstract_unionsplit), src.code) == 2 +# @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +# end +# let src = code_typed1((Union{Type,Number},)) do x +# abstract_unionsplit(x) +# end +# @test count(isinvoke(:abstract_unionsplit), src.code) == 2 +# @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +# end + +# @constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) +# @constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) +# let src = code_typed1((Any,)) do x +# abstract_unionsplit_fallback(x) +# end +# @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 +# @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +# end +# let src = code_typed1((Union{Type,Number},)) do x +# abstract_unionsplit_fallback(x) +# end +# @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 +# @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +# end @constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x)) @constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) @@ -1259,3 +1259,20 @@ end @test 
fully_eliminated() do
     return maybe_error_int(1)
 end
+
+# basic tests for inlining of `apply_type` in the presence of unmatched type parameters
+f44656(x) = Val{x}()
+
+function g44656()
+    for i = 1:10000
+        Base.donotdelete(Val{i}())
+    end
+end
+
+let srcs = (code_typed1(f44656, (Any,)),
+            code_typed1(g44656))
+    for src in srcs
+        @test count(isnew, src.code) == 1
+        @test count(iscall((src, Core.apply_type)), src.code) == 0
+    end
+end