diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 4f6c0a6d6c243..84864404f008e 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -376,11 +376,28 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector boundscheck = :off end end + if !validate_sparams(sparam_vals) + if def.isva + nonva_args = argexprs[1:end-1] + va_arg = argexprs[end] + tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...) + tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args]) + tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline)) + apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg) + sparam_vals = insert_node_here!(compact, + effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline))) + else + sparam_vals = insert_node_here!(compact, + effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline))) + end + end # If the iterator already moved on to the next basic block, # temporarily re-open in again. local return_value sig = def.sig # Special case inlining that maintains the current basic block if there's only one BB in the target + new_new_offset = length(compact.new_new_nodes) + late_fixup_offset = length(compact.late_fixup) if spec.linear_inline_eligible #compact[idx] = nothing inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) @@ -389,7 +406,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. 
inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) val = stmt′.val return_value = SSAValue(idx′) @@ -402,7 +419,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx else bb_offset, post_bb_id = popfirst!(todo_bbs) @@ -416,7 +433,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -436,7 +453,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx compact.active_result_bb = inline_compact.active_result_bb if length(pn.edges) == 1 @@ -460,7 +477,8 @@ function fix_va_argexprs!(compact::IncrementalCompact, push!(tuple_typs, argextype(arg, compact)) end tuple_typ = tuple_tfunc(tuple_typs) - push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx))) + tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx) + push!(newargexprs, insert_node_here!(compact, tuple_inst)) return newargexprs end @@ -875,8 +893,26 @@ function 
validate_sparams(sparams::SimpleVector) return true end +function may_have_fcalls(m::Method) + may_have_fcall = true + if isdefined(m, :source) + src = m.source + isa(src, Vector{UInt8}) && (src = uncompressed_ir(m)) + if isa(src, CodeInfo) + may_have_fcall = src.has_fcall + end + end + return may_have_fcall +end + +function can_inline_typevars(m::MethodMatch, argtypes::Vector{Any}) + may_have_fcalls(m.method) && return false + any(@nospecialize(x) -> x isa UnionAll, argtypes[2:end]) && return false + return true +end + function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig, - flag::UInt8, state::InliningState) + flag::UInt8, state::InliningState, allow_typevars::Bool = false) method = match.method spec_types = match.spec_types @@ -898,8 +934,9 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig, end end - # Bail out if any static parameters are left as TypeVar - validate_sparams(match.sparams) || return nothing + if !validate_sparams(match.sparams) + (allow_typevars && can_inline_typevars(match, argtypes)) || return nothing + end et = state.et @@ -1231,6 +1268,9 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo}, flag::UInt8, sig::Signature, state::InliningState) argtypes = sig.argtypes cases = InliningCase[] + local only_method = nothing + local meth::MethodLookupResult + local revisit_idx = nothing local any_fully_covered = false local handled_all_cases = true for i in 1:length(infos) @@ -1243,14 +1283,58 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo}, # No applicable methods; try next union split handled_all_cases = false continue + else + if length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1].method + elseif only_method !== meth[1].method + only_method = false + end + else + only_method = false + end end - for match in meth - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true) + for (j, 
match) in enumerate(meth) any_fully_covered |= match.fully_covers + if !validate_sparams(match.sparams) + if !match.fully_covers + handled_all_cases = false + continue + end + if revisit_idx === nothing + revisit_idx = (i, j) + else + handled_all_cases = false + revisit_idx = nothing + end + else + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false) + end end end - if !handled_all_cases + if handled_all_cases && revisit_idx !== nothing + # we handled everything except one match with unmatched sparams, + # so try to handle it by bypassing validate_sparams + (i, j) = revisit_idx + match = infos[i].results[j] + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) + elseif length(cases) == 0 && only_method isa Method + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even in the presence of unmatched sparams + # -- But don't try it if we already tried to handle the match in the revisit_idx + # case, because that'll (necessarily) be the same method. 
+ if length(infos) > 1 + atype = argtypes_to_type(argtypes) + (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector + match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) + else + @assert length(meth) == 1 + match = meth[1] + end + handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) || return nothing + any_fully_covered = handled_all_cases = match.fully_covers + elseif !handled_all_cases # if we've not seen all candidates, union split is valid only for dispatch tuples filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end @@ -1286,10 +1370,10 @@ function compute_inlining_cases(info::ConstCallInfo, case = concrete_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, ConstPropResult) - handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true) + handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false) end end end @@ -1324,14 +1408,14 @@ end function handle_match!( match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}, allow_abstract::Bool = false) + cases::Vector{InliningCase}, allow_abstract::Bool, allow_typevars::Bool) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false - # we may see duplicated dispatch signatures here when a signature gets widened + # We may see duplicated dispatch signatures here when a signature gets widened # during abstract interpretation: for the purpose of inlining, we can just skip - # processing this dispatch candidate - 
_any(case->case.sig === spec_types, cases) && return true - item = analyze_method!(match, argtypes, nothing, flag, state) + # processing this dispatch candidate (unless unmatched type parameters are present) + !allow_typevars && _any(case->case.sig === spec_types, cases) && return true + item = analyze_method!(match, argtypes, nothing, flag, state, allow_typevars) item === nothing && return false push!(cases, InliningCase(spec_types, item)) return true @@ -1339,7 +1423,7 @@ end function handle_const_prop_result!( result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}, allow_abstract::Bool = false) + cases::Vector{InliningCase}, allow_abstract::Bool) (; mi) = item = InliningTodo(result.result, argtypes) spec_types = mi.specTypes allow_abstract || isdispatchtuple(spec_types) || return false @@ -1624,15 +1708,16 @@ function late_inline_special_case!( end function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact) compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS compact.result[idx][:line] += linetable_offset - return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck) + return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx) end function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol) + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, + boundscheck::Symbol, compact::IncrementalCompact, idx::Int) if isa(val, Argument) return arg_replacements[val.n] end @@ -1640,14 +1725,20 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, e = val::Expr head = e.head if head === :static_parameter - return quoted(spvals[e.args[1]::Int]) - 
elseif head === :cfunction + if isa(spvals, SimpleVector) + return quoted(spvals[e.args[1]::Int]) + else + ret = insert_node!(compact, SSAValue(idx), + effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any))) + return ret + end + elseif head === :cfunction && isa(spvals, SimpleVector) @assert !isa(spsig, UnionAll) || !isempty(spvals) e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals) e.args[4] = svec(Any[ ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) for argt in e.args[4]::SimpleVector ]...) - elseif head === :foreigncall + elseif head === :foreigncall && isa(spvals, SimpleVector) @assert !isa(spsig, UnionAll) || !isempty(spvals) for i = 1:length(e.args) if i == 2 @@ -1671,7 +1762,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop urs = userefs(val) for op in urs - op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck) + op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx) end return urs[] end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index b2a6ee4d65586..aee3f8f1ff6fe 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -618,16 +618,13 @@ mutable struct IncrementalCompact perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)]) new_len = length(code.stmts) + length(code.new_nodes) ssa_rename = Any[SSAValue(i) for i = 1:new_len] - new_new_used_ssas = Vector{Int}() - late_fixup = Vector{Int}() bb_rename = Vector{Int}() - new_new_nodes = NewNodeStream() pending_nodes = NewNodeStream() pending_perm = Int[] return new(code, parent.result, parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas, - late_fixup, perm, 1, - new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm, + 
parent.late_fixup, perm, 1, + parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm, 1, result_offset, parent.active_result_bb, false, false, false) end end @@ -1490,63 +1487,89 @@ function maybe_erase_unused!( return false end -function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}) +struct FixedNode + node::Any + needs_fixup::Bool + FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup) +end + +function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool) values = Vector{Any}(undef, length(old_values)) + fixup = false for i = 1:length(old_values) isassigned(old_values, i) || continue - val = old_values[i] - if isa(val, Union{OldSSAValue, NewSSAValue}) - val = fixup_node(compact, val) - end - values[i] = val + (; node, needs_fixup) = fixup_node(compact, old_values[i], reify_new_nodes) + fixup |= needs_fixup + values[i] = node end - values + return (values, fixup) end -function fixup_node(compact::IncrementalCompact, @nospecialize(stmt)) + +function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool) if isa(stmt, PhiNode) - return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + (node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiNode(stmt.edges, node), needs_fixup) elseif isa(stmt, PhiCNode) - return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + (node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiCNode(node), needs_fixup) elseif isa(stmt, NewSSAValue) @assert stmt.id < 0 - return SSAValue(length(compact.result) - stmt.id) + if reify_new_nodes + val = SSAValue(length(compact.result) - stmt.id) + return FixedNode(val, false) + else + return FixedNode(stmt, true) + end elseif isa(stmt, OldSSAValue) val = compact.ssa_rename[stmt.id] if isa(val, SSAValue) - # If `val.id` is greater than the 
length of `compact.result` or - # `compact.used_ssas`, this SSA value is in `new_new_nodes`, so - # don't count the use compact.used_ssas[val.id] += 1 end - return val + return FixedNode(val, false) else urs = userefs(stmt) + fixup = false for ur in urs val = ur[] if isa(val, Union{NewSSAValue, OldSSAValue}) - ur[] = fixup_node(compact, val) + (;node, needs_fixup) = fixup_node(compact, val, reify_new_nodes) + fixup |= needs_fixup + ur[] = node end end - return urs[] + return FixedNode(urs[], fixup) end end -function just_fixup!(compact::IncrementalCompact) - resize!(compact.used_ssas, length(compact.result)) - append!(compact.used_ssas, compact.new_new_used_ssas) - empty!(compact.new_new_used_ssas) - for idx in compact.late_fixup +function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing} = nothing) + if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish! + resize!(compact.used_ssas, length(compact.result)) + append!(compact.used_ssas, compact.new_new_used_ssas) + empty!(compact.new_new_used_ssas) + end + off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1) + set_off = off + for i in off:length(compact.late_fixup) + idx = compact.late_fixup[i] stmt = compact.result[idx][:inst] - new_stmt = fixup_node(compact, stmt) - (stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt) - end - for idx in 1:length(compact.new_new_nodes) - node = compact.new_new_nodes.stmts[idx] - stmt = node[:inst] - new_stmt = fixup_node(compact, stmt) - if new_stmt !== stmt - node[:inst] = new_stmt + (;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing) + (stmt === node) || (compact.result[idx][:inst] = node) + if needs_fixup + compact.late_fixup[set_off] = idx + set_off += 1 + end + end + if late_fixup_offset !== nothing + resize!(compact.late_fixup, set_off-1) + end + off = new_new_nodes_offset === nothing ? 
1 : (new_new_nodes_offset+1) + for idx in off:length(compact.new_new_nodes) + new_node = compact.new_new_nodes.stmts[idx] + stmt = new_node[:inst] + (;node) = fixup_node(compact, stmt, late_fixup_offset === nothing) + if node !== stmt + new_node[:inst] = node + end + end end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 594b77b38654a..659f51b754a2c 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -720,6 +720,97 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. should never happen end +function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr) + if length(stmt.args) != 4 + return + end + + vec = stmt.args[3] + val = stmt.args[4] + valT = argextype(val, compact) + (isa(valT, Const) && isa(valT.val, Int)) || return + valI = valT.val::Int + (1 <= valI) || return + + if isa(vec, SimpleVector) + if valI <= length(vec) + compact[idx] = quoted(vec[valI]) + end + return + end + + if isa(vec, SSAValue) + def = compact[vec][:inst] + if is_known_call(def, Core.svec, compact) + nargs = length(def.args) + if valI <= nargs-1 + compact[idx] = def.args[valI+1] + end + return + elseif is_known_call(def, Core._compute_sparams, compact) + res = _lift_svec_ref(def, compact) + if res !== nothing + compact[idx] = res + end + return + end + end +end + +function _lift_svec_ref(def::Expr, compact::IncrementalCompact) + # TODO: We could do the whole lifting machinery here, but really all + # we want to do is clean this up when it got inserted by inlining, + # which always targets simple `svec` call or `_compute_sparams`, + # so this specialized lifting would be enough + m = argextype(def.args[2], compact) + isa(m, Const) || return nothing + m = m.val + isa(m, Method) || return nothing + # TODO: More general structural analysis of the intersection + length(def.args) >= 3 || return nothing + sig = m.sig + isa(sig, UnionAll) || return nothing + tvar = sig.var + sig = sig.body + isa(sig, 
DataType) || return nothing + sig.name === Tuple.name || return nothing + length(sig.parameters) >= 1 || return nothing + + i = findfirst(j->has_typevar(sig.parameters[j], tvar), 1:length(sig.parameters)) + i === nothing && return nothing + _any(j->has_typevar(sig.parameters[j], tvar), i+1:length(sig.parameters)) && return nothing + + arg = sig.parameters[i] + isa(arg, DataType) || return nothing + + rarg = def.args[2 + i] + isa(rarg, SSAValue) || return nothing + argdef = compact[rarg][:inst] + if isexpr(argdef, :new) + rarg = argdef.args[1] + isa(rarg, SSAValue) || return nothing + argdef = compact[rarg][:inst] + end + + is_known_call(argdef, Core.apply_type, compact) || return nothing + length(argdef.args) == 3 || return nothing + + applyT = argextype(argdef.args[2], compact) + isa(applyT, Const) || return nothing + applyT = applyT.val + + isa(applyT, UnionAll) || return nothing + applyTvar = applyT.var + applyTbody = applyT.body + + isa(applyTbody, DataType) || return nothing + applyTbody.name == arg.name || return nothing + length(applyTbody.parameters) == length(arg.parameters) == 1 || return nothing + applyTbody.parameters[1] === applyTvar || return nothing + arg.parameters[1] === tvar || return nothing + return argdef.args[3] +end + # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, # which can be very large sometimes, and program counters in question are often very sparse const SPCSet = IdSet{Int} @@ -828,6 +919,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing, InliningState} = nothin else # TODO: This isn't the best place to put these if is_known_call(stmt, typeassert, compact) canonicalize_typeassert!(compact, idx, stmt) + elseif is_known_call(stmt, Core._svec_ref, compact) + lift_svec_ref!(compact, idx, stmt) elseif is_known_call(stmt, (===), compact) lift_comparison!(===, compact, idx, stmt, lifting_cache) elseif is_known_call(stmt, isa, compact) diff --git a/base/compiler/utilities.jl 
b/base/compiler/utilities.jl index c01c0dffec505..0844550f97aef 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -173,6 +173,8 @@ function subst_trivial_bounds(@nospecialize(atype)) return UnionAll(v, subst_trivial_bounds(atype.body)) end +has_typevar(@nospecialize(t), v::TypeVar) = ccall(:jl_has_typevar, Cint, (Any, Any), t, v) != 0 + # If removing trivial vars from atype results in an equivalent type, use that # instead. Otherwise we can get a case like issue #38888, where a signature like # f(x::S) where S<:Int diff --git a/base/essentials.jl b/base/essentials.jl index 50cdde9f3adc2..d1313b0740995 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -686,13 +686,7 @@ end # SimpleVector -function getindex(v::SimpleVector, i::Int) - @boundscheck if !(1 <= i <= length(v)) - throw(BoundsError(v,i)) - end - return ccall(:jl_svec_ref, Any, (Any, Int), v, i - 1) -end - +@eval getindex(v::SimpleVector, i::Int) = Core._svec_ref($(Expr(:boundscheck)), v, i) function length(v::SimpleVector) return ccall(:jl_svec_len, Int, (Any,), v) end diff --git a/base/show.jl b/base/show.jl index 22f3f8ec0786f..4b680bc50209b 100644 --- a/base/show.jl +++ b/base/show.jl @@ -1,5 +1,7 @@ # This file is a part of Julia. 
License is MIT: https://julialang.org/license +using Core.Compiler: has_typevar + function show(io::IO, ::MIME"text/plain", u::UndefInitializer) show(io, u) get(io, :compact, false) && return @@ -545,8 +547,6 @@ function print_without_params(@nospecialize(x)) return isa(b, DataType) && b.name.wrapper === x end -has_typevar(@nospecialize(t), v::TypeVar) = ccall(:jl_has_typevar, Cint, (Any, Any), t, v)!=0 - function io_has_tvar_name(io::IOContext, name::Symbol, @nospecialize(x)) for (key, val) in io.dict if key === :unionall_env && val isa TypeVar && val.name === name && has_typevar(x, val) diff --git a/src/builtin_proto.h b/src/builtin_proto.h index d9520bd251b86..f61f76c3966f8 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -59,6 +59,8 @@ DECLARE_BUILTIN(compilerbarrier); DECLARE_BUILTIN(getglobal); DECLARE_BUILTIN(setglobal); DECLARE_BUILTIN(finalizer); +DECLARE_BUILTIN(_compute_sparams); +DECLARE_BUILTIN(_svec_ref); JL_CALLABLE(jl_f_invoke_kwsorter); #ifdef DEFINE_BUILTIN_GLOBALS @@ -73,7 +75,8 @@ JL_CALLABLE(jl_f__setsuper); JL_CALLABLE(jl_f__equiv_typedef); JL_CALLABLE(jl_f_get_binding_type); JL_CALLABLE(jl_f_set_binding_type); - +JL_CALLABLE(jl_f__compute_sparams); +JL_CALLABLE(jl_f__svec_ref); #ifdef __cplusplus } #endif diff --git a/src/builtins.c b/src/builtins.c index a41a565b45346..56c3310d511f5 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1626,6 +1626,36 @@ JL_CALLABLE(jl_f_finalizer) return jl_nothing; } +JL_CALLABLE(jl_f__compute_sparams) +{ + JL_NARGSV(_compute_sparams, 1); + jl_method_t *m = (jl_method_t*)args[0]; + JL_TYPECHK(_compute_sparams, method, (jl_value_t*)m); + jl_datatype_t *tt = jl_inst_arg_tuple_type(args[1], &args[2], nargs-1, 1); + jl_svec_t *env = jl_emptysvec; + JL_GC_PUSH2(&env, &tt); + jl_type_intersection_env((jl_value_t*)tt, m->sig, &env); + JL_GC_POP(); + return (jl_value_t*)env; +} + +JL_CALLABLE(jl_f__svec_ref) +{ + JL_NARGS(_svec_ref, 3, 3); + jl_value_t *b = args[0]; + jl_svec_t *s = 
(jl_svec_t*)args[1]; + jl_value_t *i = (jl_value_t*)args[2]; + JL_TYPECHK(_svec_ref, bool, b); + JL_TYPECHK(_svec_ref, simplevector, (jl_value_t*)s); + JL_TYPECHK(_svec_ref, long, i); + size_t len = jl_svec_len(s); + ssize_t idx = jl_unbox_long(i); + if (idx < 1 || idx > len) { + jl_bounds_error_int((jl_value_t*)s, idx); + } + return jl_svec_ref(s, idx-1); +} + static int equiv_field_types(jl_value_t *old, jl_value_t *ft) { size_t nf = jl_svec_len(ft); @@ -1998,6 +2028,8 @@ void jl_init_primitives(void) JL_GC_DISABLED jl_builtin_donotdelete = add_builtin_func("donotdelete", jl_f_donotdelete); jl_builtin_compilerbarrier = add_builtin_func("compilerbarrier", jl_f_compilerbarrier); add_builtin_func("finalizer", jl_f_finalizer); + add_builtin_func("_compute_sparams", jl_f__compute_sparams); + add_builtin_func("_svec_ref", jl_f__svec_ref); // builtin types add_builtin("Any", (jl_value_t*)jl_any_type); diff --git a/src/ircode.c b/src/ircode.c index 5dae26cfadf8c..f0e7cbd389d78 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -797,6 +797,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) ios_write(s.s, (char*)jl_array_data(code->codelocs), nstmt * sizeof(int32_t)); } + write_uint8(s.s, code->has_fcall); write_uint8(s.s, s.relocatability); ios_flush(s.s); @@ -809,6 +810,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) jl_gc_enable(en); JL_UNLOCK(&m->writelock); // Might GC JL_GC_POP(); + return v; } @@ -878,6 +880,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t ios_readall(s.s, (char*)jl_array_data(code->codelocs), nstmt * sizeof(int32_t)); } + code->has_fcall = read_uint8(s.s); (void) read_uint8(s.s); // relocatability assert(ios_getc(s.s) == -1); @@ -892,6 +895,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t code->rettype = metadata->rettype; code->parent = metadata->def; } + return code; } diff --git a/src/jltypes.c b/src/jltypes.c 
index 2dc185db27c9b..9da7209f17f2c 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2401,7 +2401,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_info_type = jl_new_datatype(jl_symbol("CodeInfo"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(20, + jl_perm_symsvec(21, "code", "codelocs", "ssavaluetypes", @@ -2420,9 +2420,10 @@ void jl_init_types(void) JL_GC_DISABLED "inlining_cost", "propagate_inbounds", "pure", + "has_fcall", "constprop", "purity"), - jl_svec(20, + jl_svec(21, jl_array_any_type, jl_array_int32_type, jl_any_type, @@ -2441,6 +2442,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_uint16_type, jl_bool_type, jl_bool_type, + jl_bool_type, jl_uint8_type, jl_uint8_type), jl_emptysvec, diff --git a/src/julia.h b/src/julia.h index f780522081c3c..9179e80c3083f 100644 --- a/src/julia.h +++ b/src/julia.h @@ -283,6 +283,7 @@ typedef struct _jl_code_info_t { uint16_t inlining_cost; uint8_t propagate_inbounds; uint8_t pure; + uint8_t has_fcall; // uint8 settings uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none _jl_purity_overrides_t purity; diff --git a/src/method.c b/src/method.c index e5d6771b080d6..52748ede0d35a 100644 --- a/src/method.c +++ b/src/method.c @@ -377,7 +377,9 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_return_sym) { jl_array_ptr_set(body, j, jl_new_struct(jl_returnnode_type, jl_exprarg(st, 0))); } - + else if (jl_is_expr(st) && (((jl_expr_t*)st)->head == jl_foreigncall_sym || ((jl_expr_t*)st)->head == jl_cfunction_sym)) { + li->has_fcall = 1; + } if (is_flag_stmt) jl_array_uint8_set(li->ssaflags, j, 0); else { @@ -471,6 +473,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) src->inlining_cost = UINT16_MAX; src->propagate_inbounds = 0; src->pure = 0; + src->has_fcall = 0; src->edges = jl_nothing; src->constprop = 0; src->purity.bits = 0; diff --git a/src/staticdata.c b/src/staticdata.c index 2e0c69dad3afc..1b5fc9bbe3edc 100644 
--- a/src/staticdata.c +++ b/src/staticdata.c @@ -315,7 +315,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_ifelse, &jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype, &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type, &jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete, &jl_f_compilerbarrier, - &jl_f_getglobal, &jl_f_setglobal, &jl_f_finalizer, + &jl_f_getglobal, &jl_f_setglobal, &jl_f_finalizer, &jl_f__compute_sparams, &jl_f__svec_ref, NULL }; typedef struct { diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index a5c9fec8080b7..538d9e0cd101d 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -79,7 +79,7 @@ const TAGS = Any[ @assert length(TAGS) == 255 -const ser_version = 19 # do not make changes without bumping the version #! +const ser_version = 20 # do not make changes without bumping the version #! format_version(::AbstractSerializer) = ser_version format_version(s::Serializer) = s.version @@ -1191,6 +1191,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) end ci.propagate_inbounds = deserialize(s) ci.pure = deserialize(s) + if format_version(s) >= 20 + ci.has_fcall = deserialize(s) + end if format_version(s) >= 14 ci.constprop = deserialize(s)::UInt8 end diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index c94e12418e9a8..2263a538f1eaa 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -1394,9 +1394,7 @@ let src = code_typed1(Tuple{Any}) do x DoAllocNoEscapeSparam(x) end end - # This requires more inlining enhancments. For now just make sure this - # doesn't error. 
- @test count(isnew, src.code) in (0, 1) # == 0 + @test count(isnew, src.code) == 0 end # Test noinline finalizer @@ -1519,3 +1517,27 @@ function oc_capture_oc(z) return oc2(z) end @test fully_eliminated(oc_capture_oc, (Int,)) + +@testset "Inlining with unmatched type parameters" begin + @eval struct OldVal{T} + x::T + (OV::Type{OldVal{T}})() where T = $(Expr(:new, :OV)) + end + let f(x) = OldVal{x}() + g() = [ Base.donotdelete(OldVal{i}()) for i in 1:10000 ] + h() = begin + f(x::OldVal{i}) where {i} = i + r = 0 + for i = 1:10000 + r += f(OldVal{i}()) + end + return r + end + srcs = (code_typed1(f, (Any,)), + code_typed1(g), + code_typed1(h)) + for src in srcs + @test !any(@nospecialize(x) -> isexpr(x, :call) && length(x.args) == 1, src.code) + end + end +end diff --git a/test/operators.jl b/test/operators.jl index cd0ca743d2cda..8192e13b73a7f 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -325,7 +325,7 @@ end @test Returns(val)(1) === val @test sprint(show, Returns(1.0)) == "Returns{Float64}(1.0)" - illtype = Vector{Core._typevar(:T, Union{}, Any)} + illtype = Vector{Core.TypeVar(:T)} @test Returns(illtype) == Returns{DataType}(illtype) end diff --git a/test/testhelpers/OffsetArrays.jl b/test/testhelpers/OffsetArrays.jl index 01b34df8e18a9..705bd07b2878c 100644 --- a/test/testhelpers/OffsetArrays.jl +++ b/test/testhelpers/OffsetArrays.jl @@ -100,7 +100,7 @@ end # function offset_coerce(::Type{Base.OneTo{T}}, r::IdOffsetRange) where T<:Integer # rc, o = offset_coerce(Base.OneTo{T}, r.parent) -# Fallback, specialze this method if `convert(I, r)` doesn't do what you need +# Fallback, specialize this method if `convert(I, r)` doesn't do what you need offset_coerce(::Type{I}, r::AbstractUnitRange) where I<:AbstractUnitRange = convert(I, r)::I, 0