From 5320bd974f905068d6d0df31eec5db8380b7a947 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 3 May 2024 04:42:45 +0000 Subject: [PATCH] Move nargs/isva to CodeInfo This changes the canonical source of truth for va handling from `Method` to `CodeInfo`. There are multiple goals for this change: 1. This addresses a longstanding complaint about the way that CodeInfo-returning generated functions work. Previously, the va-ness or not of the returned CodeInfo always had to match that of the generator. For Cassette-like transforms that generally have one big generator function that is varargs (while then looking up lowered code that is not varargs), this could become quite annoying. It's possible to work around, but there is really no good reason to tie the two together. As we observed when we implemented OpaqueClosures, the vararg-ness of the signature and the `vararg arguments`->`tuple` transformation are mostly independent concepts. With this PR, generated functions can return CodeInfos with whatever combination of nargs/isva is convenient. 2. This change requires clarifying where the va processing boundary is in inference. #54076 was already moving in that direction for irinterp, and this essentially does much of the same for regular inference. As a consequence the constprop cache is now using non-va-cooked signatures, which I think is preferable. 3. This further decouples codegen from the presence of a `Method` (which is already not assumed, since the code being generated could be a toplevel thunk, but some codegen features are only available to things that come from Methods). There are a number of upcoming features that will require codegen of things that are not quite method specializations (See design doc linked in #52797 and things like #50641). This helps pave the road for that. 4. I've previously considered expanding the kinds of vararg signatures that can be described (see e.g. 
#53851), which also requires a decoupling of the signature and ast notions of vararg. This again lays the groundwork for that, although I have no immediate plans to implement this change. Impact-wise, this adds an internal field, which is not too breaking, but downstream clients vary in how they construct their `CodeInfo`s and the current way they're doing it will likely be incorrect after this change, so they will require a small two-line adjustment. We should perhaps consider pulling out some of the more common patterns into a more stable package, since the `CodeInfo` interface has changed in most of the last few releases, but that's a separate issue. --- base/compiler/abstractinterpretation.jl | 43 ++--- base/compiler/inferenceresult.jl | 183 +++++++++++----------- base/compiler/inferencestate.jl | 10 +- base/compiler/optimize.jl | 5 +- base/compiler/ssair/legacy.jl | 8 +- base/compiler/ssair/slot2ssa.jl | 2 +- base/compiler/typeinfer.jl | 2 +- base/compiler/types.jl | 3 - src/codegen.cpp | 24 ++- src/ircode.c | 16 +- src/julia_internal.h | 1 + src/method.c | 9 ++ stdlib/Serialization/src/Serialization.jl | 8 +- stdlib/Test/src/precompile.jl | 6 +- test/compiler/contextual.jl | 22 +++ 15 files changed, 189 insertions(+), 153 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 3b5d01d8b15569..f6fefe8f54c274 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -479,7 +479,12 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe if isa(rt, InterConditional) && rt.slot == i return rt else - thentype = elsetype = tmeet(𝕃ᵢ, widenslotwrapper(argtypes[i]), fieldtype(sig, i)) + argt = widenslotwrapper(argtypes[i]) + if isvarargtype(argt) + @assert fieldcount(sig) == i + argt = unwrapva(argt) + end + thentype = elsetype = tmeet(𝕃ᵢ, argt, fieldtype(sig, i)) condval = maybe_extract_const_bool(rt) condval === true && (elsetype = Bottom) condval === 
false && (thentype = Bottom) @@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, # N.B. remarks are emitted within `const_prop_entry_heuristic` return nothing end - nargs::Int = method.nargs - method.isva && (nargs -= 1) - length(arginfo.argtypes) < nargs && return nothing if !const_prop_argument_heuristic(interp, arginfo, sv) add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics") return nothing end all_overridden = is_all_overridden(interp, arginfo, sv) - if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv) + if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv) add_remark!(interp, sv, "[constprop] Disabled by function heuristic") return nothing end @@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: end function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), - arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState) + arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState) argtypes = arginfo.argtypes - if nargs > 1 + if length(argtypes) > 1 𝕃ᵢ = typeinf_lattice(interp) if istopfunction(f, :getindex) || istopfunction(f, :setindex!) 
arrty = argtypes[2] @@ -1349,20 +1351,6 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, end given_argtypes[i] = widenslotwrapper(argtype) end - if condargs !== nothing - given_argtypes = let condargs=condargs - va_process_argtypes(𝕃, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int - # invalidate `Conditional` imposed on varargs - for (slotid, i) in condargs - if slotid ≥ last && (1 ≤ i ≤ length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise - isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i]) - end - end - end - end - else - given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi) - end return pick_const_args!(𝕃, given_argtypes, cache_argtypes) end @@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: return CallMeta(res, exct, effects, retinfo) end -function argtype_by_index(argtypes::Vector{Any}, i::Int) +function argtype_by_index(argtypes::Vector{Any}, i::Integer) n = length(argtypes) na = argtypes[n] if isvarargtype(na) @@ -2890,12 +2878,12 @@ end struct BestguessInfo{Interp<:AbstractInterpreter} interp::Interp bestguess - nargs::Int + nargs::UInt slottypes::Vector{Any} changes::VarTable - function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int, + function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt, slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter - new{Interp}(interp, bestguess, nargs, slottypes, changes) + new{Interp}(interp, bestguess, Int(nargs), slottypes, changes) end end @@ -2970,7 +2958,7 @@ end # pick up the first "interesting" slot, convert `rt` to its `Conditional` # TODO: ideally we want `Conditional` and `InterConditional` to convey # constraints on multiple slots - for slot_id = 1:info.nargs + for slot_id = 1:Int(info.nargs) rt = bool_rt_to_conditional(rt, slot_id, info) rt isa InterConditional && break end @@ 
-2981,6 +2969,9 @@ end ⊑ᵢ = ⊑(typeinf_lattice(info.interp)) old = info.slottypes[slot_id] new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional + if isvarargtype(old) || isvarargtype(new) + return rt + end if new ⊑ᵢ old && !(old ⊑ᵢ new) if isa(rt, Const) val = rt.val diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index 2575429fbf924f..7ebbb381d4cfbc 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -24,27 +24,53 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, for i = 1:length(argtypes) given_argtypes[i] = widenslotwrapper(argtypes[i]) end - given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi) return pick_const_args!(𝕃, given_argtypes, cache_argtypes) end +function pick_const_arg(𝕃::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype)) + if !is_argtype_match(𝕃, given_argtype, cache_argtype, false) + # prefer the argtype we were given over the one computed from `mi` + if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) && + !⊏(𝕃, given_argtype, cache_argtype)) + # if the type information of this `PartialStruct` is less strict than + # declared method signature, narrow it down using `tmeet` + given_argtype = tmeet(𝕃, given_argtype, cache_argtype) + end + else + given_argtype = cache_argtype + end + return given_argtype +end + function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any}) - nargtypes = length(given_argtypes) - @assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`" - for i = 1:nargtypes - given_argtype = given_argtypes[i] - cache_argtype = cache_argtypes[i] - if !is_argtype_match(𝕃, given_argtype, cache_argtype, false) - # prefer the argtype we were given over the one computed from `mi` - if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) && - 
!โŠ(๐•ƒ, given_argtype, cache_argtype)) - # if the type information of this `PartialStruct` is less strict than - # declared method signature, narrow it down using `tmeet` - given_argtypes[i] = tmeet(๐•ƒ, given_argtype, cache_argtype) - end + if length(given_argtypes) == 0 || length(cache_argtypes) == 0 + return Any[] + end + given_va = given_argtypes[end] + cache_va = cache_argtypes[end] + if isvarargtype(given_va) + if isvarargtype(cache_va) + # Process the common prefix, then join + nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1) + resize!(given_argtypes, nprocessargs+1) + given_argtypes[end] = Vararg{pick_const_arg(๐•ƒ, unwrapva(given_va), unwrapva(cache_va))} else - given_argtypes[i] = cache_argtype + nprocessargs = length(cache_argtypes) + resize!(given_argtypes, nprocessargs) end + elseif isvarargtype(cache_va) + nprocessargs = length(given_argtypes) + resize!(given_argtypes, nprocessargs) + else + @assert length(given_argtypes) == length(cache_argtypes) + nprocessargs = length(given_argtypes) + resize!(given_argtypes, nprocessargs) + end + for i = 1:nprocessargs + given_argtype = argtype_by_index(given_argtypes, i) + cache_argtype = argtype_by_index(cache_argtypes, i) + given_argtype = pick_const_arg(๐•ƒ, given_argtype, cache_argtype) + given_argtypes[i] = given_argtype end return given_argtypes end @@ -60,25 +86,33 @@ function is_argtype_match(๐•ƒ::AbstractLattice, end end -va_process_argtypes(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) = - va_process_argtypes(Returns(nothing), ๐•ƒ, given_argtypes, mi) -function va_process_argtypes(@specialize(va_handler!), ๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) - def = mi.def::Method - isva = def.isva - nargs = Int(def.nargs) - if isva || isvarargtype(given_argtypes[end]) - isva_given_argtypes = Vector{Any}(undef, nargs) +function va_process_argtypes(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool) + if 
isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end])) + isva_given_argtypes = Vector{Any}(undef, Int(nargs)) for i = 1:(nargs-isva) - isva_given_argtypes[i] = argtype_by_index(given_argtypes, i) + newarg = argtype_by_index(given_argtypes, i) + if isva && has_conditional(𝕃) && isa(newarg, Conditional) + if newarg.slotid > (nargs-isva) + newarg = widenconditional(newarg) + end + end + isva_given_argtypes[i] = newarg end if isva if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end]) last = length(given_argtypes) else last = nargs + if has_conditional(𝕃) + for i = last:length(given_argtypes) + newarg = given_argtypes[i] + if isa(newarg, Conditional) && newarg.slotid > (nargs-isva) + given_argtypes[i] = widenconditional(newarg) + end + end + end end isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end]) - va_handler!(isva_given_argtypes, last) end return isva_given_argtypes end @@ -87,84 +121,44 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi end function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes)) - toplevel = method === nothing - isva = !toplevel && method.isva mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...] - nargs::Int = toplevel ? 0 : method.nargs - cache_argtypes = Vector{Any}(undef, nargs) - # First, if we're dealing with a varargs method, then we set the last element of `args` - # to the appropriate `Tuple` type or `PartialStruct` instance. 
- mi_argtypes_length = length(mi_argtypes) - if !toplevel && isva - if specTypes::Type == Tuple - mi_argtypes = Any[Any for i = 1:nargs] - if nargs > 1 - mi_argtypes[end] = Tuple - end - vargtype = Tuple - else - if nargs > mi_argtypes_length - va = mi_argtypes[mi_argtypes_length] - if isvarargtype(va) - new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes) - vargtype = Tuple{new_va} - else - vargtype = Tuple{} - end - else - vargtype_elements = Any[] - for i in nargs:mi_argtypes_length - p = mi_argtypes[i] - p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p) - push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes))) - end - for i in 1:length(vargtype_elements) - atyp = vargtype_elements[i] - if issingletontype(atyp) - # replace singleton types with their equivalent Const object - vargtype_elements[i] = Const(atyp.instance) - elseif isconstType(atyp) - vargtype_elements[i] = Const(atyp.parameters[1]) - end - end - vargtype = tuple_tfunc(fallback_lattice, vargtype_elements) - end - end - cache_argtypes[nargs] = vargtype - nargs -= 1 + nargtypes = length(mi_argtypes) + nargs = isa(method, Method) ? method.nargs : 0 + if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end]) + resize!(mi_argtypes, nargs) end # Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some # type info as we go (where possible). Note that if we're dealing with a varargs method, # we already handled the last element of `cache_argtypes` (and decremented `nargs` so that # we don't overwrite the result of that work here). 
- if mi_argtypes_length > 0 - tail_index = nargtypes = min(mi_argtypes_length, nargs) - local lastatype - for i = 1:nargtypes - atyp = mi_argtypes[i] - if i == nargtypes && isvarargtype(atyp) - atyp = unwrapva(atyp) - tail_index -= 1 - end - atyp = unwraptv(atyp) - if issingletontype(atyp) - # replace singleton types with their equivalent Const object - atyp = Const(atyp.instance) - elseif isconstType(atyp) - atyp = Const(atyp.parameters[1]) - else - atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes)) - end - i == nargtypes && (lastatype = atyp) - cache_argtypes[i] = atyp + tail_index = min(nargtypes, nargs) + local lastatype + for i = 1:nargtypes + atyp = mi_argtypes[i] + wasva = false + if i == nargtypes && isvarargtype(atyp) + wasva = true + atyp = unwrapva(atyp) end - for i = (tail_index+1):nargs - cache_argtypes[i] = lastatype + atyp = unwraptv(atyp) + if issingletontype(atyp) + # replace singleton types with their equivalent Const object + atyp = Const(atyp.instance) + elseif isconstType(atyp) + atyp = Const(atyp.parameters[1]) + else + atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes)) end - else - @assert nargs == 0 "invalid specialization of method" # wrong number of arguments + mi_argtypes[i] = atyp + if wasva + lastatype = atyp + mi_argtypes[end] = Vararg{atyp} + end + end + for i = (tail_index+1):(nargs-1) + mi_argtypes[i] = lastatype end - return cache_argtypes + return mi_argtypes end # eliminate free `TypeVar`s in order to make the life much easier down the road: @@ -184,7 +178,6 @@ function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes: cache::Vector{InferenceResult}) method = mi.def::Method nargtypes = length(given_argtypes) - @assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`" for cached_result in cache cached_result.linfo === mi || @goto next_cache cache_argtypes = cached_result.argtypes diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 
169d543f3249c7..4e72d5036464e0 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -302,6 +302,9 @@ mutable struct InferenceState bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ] bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots) argtypes = result.argtypes + + argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva) + nargtypes = length(argtypes) for i = 1:nslots argtyp = (i > nargtypes) ? Bottom : argtypes[i] @@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState) end function narguments(sv::InferenceState, include_va::Bool=true) - def = sv.linfo.def - nargs = length(sv.result.argtypes) + nargs = sv.src.nargs if !include_va - nargs -= isa(def, Method) && def.isva + nargs -= sv.src.isva end return nargs end @@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter, end method_info = MethodInfo(src) ir = inflate_ir(src, mi) - argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi) + argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva) return IRInterpretationState(interp, method_info, ir, mi, argtypes, world, codeinst.min_world, codeinst.max_world) end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 85942d3ca83b38..6c956e56b89adc 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1264,14 +1264,13 @@ end function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState) # need `ci` for the slot metadata, IR for the code svdef = sv.linfo.def - nargs = isa(svdef, Method) ? 
Int(svdef.nargs) : 0 @timeit "domtree 1" domtree = construct_domtree(ir) - defuse_insts = scan_slot_def_use(nargs, ci, ir.stmts.stmt) + defuse_insts = scan_slot_def_use(ci.nargs, ci, ir.stmts.stmt) 𝕃ₒ = optimizer_lattice(sv.inlining.interp) @timeit "construct_ssa" ir = construct_ssa!(ci, ir, sv, domtree, defuse_insts, 𝕃ₒ) # consumes `ir` # NOTE now we have converted `ir` to the SSA form and eliminated slots # let's resize `argtypes` now and remove unnecessary types for the eliminated slots - resize!(ir.argtypes, nargs) + resize!(ir.argtypes, ci.nargs) return ir end diff --git a/base/compiler/ssair/legacy.jl b/base/compiler/ssair/legacy.jl index b45db03875801c..2b0721b8d24084 100644 --- a/base/compiler/ssair/legacy.jl +++ b/base/compiler/ssair/legacy.jl @@ -10,7 +10,13 @@ the original `ci::CodeInfo` are modified. """ function inflate_ir!(ci::CodeInfo, mi::MethodInstance) sptypes = sptypes_from_meth_instance(mi) - argtypes = matching_cache_argtypes(fallback_lattice, mi) + if ci.slottypes === nothing + argtypes = va_process_argtypes(fallback_lattice, + matching_cache_argtypes(fallback_lattice, mi), + ci.nargs, ci.isva) + else + argtypes = ci.slottypes[1:ci.nargs] + end return inflate_ir!(ci, sptypes, argtypes) end function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any}) diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index f2bfa0e4c54767..4c928dd8f6b725 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -33,7 +33,7 @@ function scan_entry!(result::Vector{SlotInfo}, idx::Int, @nospecialize(stmt)) end end -function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any}) +function scan_slot_def_use(nargs::UInt, ci::CodeInfo, code::Vector{Any}) nslots = length(ci.slotflags) result = SlotInfo[SlotInfo() for i = 1:nslots] # Set defs for arguments diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index d091fb8d2f5f8b..ee3e93806f8531 100644 --- 
a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -468,7 +468,7 @@ function adjust_effects(sv::InferenceState) # this frame is known to be safe ipo_effects = Effects(ipo_effects; nothrow=true) end - if is_inaccessiblemem_or_argmemonly(ipo_effects) && all(1:narguments(sv, #=include_va=#true)) do i::Int + if is_inaccessiblemem_or_argmemonly(ipo_effects) && all(1:narguments(sv, #=include_va=#true)) do i::UInt return is_mutation_free_argtype(sv.slottypes[i]) end ipo_effects = Effects(ipo_effects; inaccessiblememonly=ALWAYS_TRUE) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 30cb0fb0f39c55..a6f5488ef67034 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -91,9 +91,6 @@ mutable struct InferenceResult is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively ci::CodeInstance # CodeInstance if this result has been added to the cache function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector}) - def = mi.def - nargs = def isa Method ? Int(def.nargs) : 0 - @assert length(argtypes) == nargs "invalid `argtypes` for `mi`" return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing, WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false) end diff --git a/src/codegen.cpp b/src/codegen.cpp index 03e138dac0394a..65c0bf6db6e016 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -8085,19 +8085,15 @@ static jl_llvm_functions_t ctx.module = jl_is_method(lam->def.method) ? 
lam->def.method->module : lam->def.module; ctx.linfo = lam; ctx.name = TSM.getModuleUnlocked()->getModuleIdentifier().data(); - size_t nreq = 0; - int va = 0; - if (jl_is_method(lam->def.method)) { - ctx.nargs = nreq = lam->def.method->nargs; - ctx.is_opaque_closure = lam->def.method->is_for_opaque_closure; - if ((nreq > 0 && jl_is_method(lam->def.value) && lam->def.method->isva)) { - assert(nreq > 0); - nreq--; - va = 1; - } + size_t nreq = src->nargs; + int va = src->isva; + ctx.nargs = nreq; + if (va) { + assert(nreq > 0); + nreq--; } - else { - ctx.nargs = 0; + if (jl_is_method(lam->def.value)) { + ctx.is_opaque_closure = lam->def.method->is_for_opaque_closure; } ctx.nReqArgs = nreq; if (va) { @@ -8161,7 +8157,7 @@ static jl_llvm_functions_t // step 3. some variable analysis size_t i; - for (i = 0; i < nreq; i++) { + for (i = 0; i < nreq && i < vinfoslen; i++) { jl_varinfo_t &varinfo = ctx.slots[i]; varinfo.isArgument = true; jl_sym_t *argname = slot_symbol(ctx, i); @@ -8675,7 +8671,7 @@ static jl_llvm_functions_t AttrBuilder param(ctx.builder.getContext()); attrs[Arg->getArgNo()] = AttributeSet::get(Arg->getContext(), param); } - for (i = 0; i < nreq; i++) { + for (i = 0; i < nreq && i < vinfoslen; i++) { jl_sym_t *s = slot_symbol(ctx, i); jl_value_t *argType = jl_nth_slot_type(lam->specTypes, i); // TODO: jl_nth_slot_type should call jl_rewrap_unionall? 
diff --git a/src/ircode.c b/src/ircode.c index 2e16d1b5b24208..9c6944bccc9146 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -458,12 +458,14 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal) } static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall, - uint8_t nospecializeinfer, uint8_t inlining, uint8_t constprop) + uint8_t nospecializeinfer, uint8_t isva, + uint8_t inlining, uint8_t constprop) { jl_code_info_flags_t flags; flags.bits.propagate_inbounds = propagate_inbounds; flags.bits.has_fcall = has_fcall; flags.bits.nospecializeinfer = nospecializeinfer; + flags.bits.isva = isva; flags.bits.inlining = inlining; flags.bits.constprop = constprop; return flags; @@ -860,7 +862,8 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) }; jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall, - code->nospecializeinfer, code->inlining, code->constprop); + code->nospecializeinfer, code->isva, + code->inlining, code->constprop); write_uint8(s.s, flags.packed); static_assert(sizeof(flags.packed) == IR_DATASIZE_FLAGS, "ir_datasize_flags is mismatched with the actual size"); write_uint16(s.s, code->purity.bits); @@ -868,6 +871,8 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) write_uint16(s.s, code->inlining_cost); static_assert(sizeof(code->inlining_cost) == IR_DATASIZE_INLINING_COST, "ir_datasize_inlining_cost is mismatched with the actual size"); + write_int32(s.s, code->nargs); + int32_t nslots = jl_array_nrows(code->slotflags); assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions write_int32(s.s, nslots); @@ -892,6 +897,8 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) if (m->is_for_opaque_closure) jl_encode_value_(&s, code->slottypes, 1); + // Slotnames. 
For regular methods, we require that m->slot_syms matches the + // CodeInfo's slotnames, so we do not need to save it here. if (m->generator) // can't optimize generated functions jl_encode_value_(&s, (jl_value_t*)jl_compress_argnames(code->slotnames), 1); @@ -943,10 +950,13 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t code->propagate_inbounds = flags.bits.propagate_inbounds; code->has_fcall = flags.bits.has_fcall; code->nospecializeinfer = flags.bits.nospecializeinfer; + code->isva = flags.bits.isva; code->purity.bits = read_uint16(s.s); code->inlining_cost = read_uint16(s.s); - size_t nslots = read_int32(&src); + code->nargs = read_int32(s.s); + + size_t nslots = read_int32(s.s); code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); ios_readall(s.s, jl_array_data(code->slotflags, char), nslots); diff --git a/src/julia_internal.h b/src/julia_internal.h index 900591fa206a77..4935b729861d2b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -632,6 +632,7 @@ typedef struct { uint8_t propagate_inbounds:1; uint8_t has_fcall:1; uint8_t nospecializeinfer:1; + uint8_t isva:1; uint8_t inlining:2; // 0 = use heuristic; 1 = aggressive; 2 = none uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none } jl_code_info_flags_bitfield_t; diff --git a/src/method.c b/src/method.c index c6177fa9dd0556..e78c3dd0629553 100644 --- a/src/method.c +++ b/src/method.c @@ -420,6 +420,10 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ir) jl_code_info_t *li = NULL; JL_GC_PUSH1(&li); li = jl_new_code_info_uninit(); + + jl_expr_t *arglist = (jl_expr_t*)jl_exprarg(ir, 0); + li->nargs = jl_array_len(arglist); + assert(jl_is_expr(ir)); jl_expr_t *bodyex = (jl_expr_t*)jl_exprarg(ir, 2); @@ -737,6 +741,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, siz } jl_error("The function body AST defined by this @generated function is not pure. 
This likely means it contains a closure, a comprehension or a generator."); } + // TODO: This should ideally be in the lambda expression, + // but currently our isva determination is non-syntactic + func->isva = def->isva; } // If this generated function has an opaque closure, cache it for @@ -899,6 +906,8 @@ JL_DLLEXPORT void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) jl_array_ptr_set(copy, i, st); } src = jl_copy_code_info(src); + src->isva = m->isva; // TODO: It would be nice to reverse this + assert(m->nargs == src->nargs); src->code = copy; jl_gc_wb(src, copy); m->slot_syms = jl_compress_argnames(src->slotnames); diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index c9cb1edea4b6be..765157025806f6 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -80,7 +80,7 @@ const TAGS = Any[ const NTAGS = length(TAGS) @assert NTAGS == 255 -const ser_version = 27 # do not make changes without bumping the version #! +const ser_version = 28 # do not make changes without bumping the version #! 
format_version(::AbstractSerializer) = ser_version format_version(s::Serializer) = s.version @@ -1260,6 +1260,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) ci.inlining_cost = inlining_cost end end + if format_version(s) >= 28 + ci.nargs = deserialize(s) + end ci.propagate_inbounds = deserialize(s) if format_version(s) < 23 deserialize(s) # `pure` field has been removed @@ -1270,6 +1273,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) if format_version(s) >= 24 ci.nospecializeinfer = deserialize(s)::Bool end + if format_version(s) >= 28 + ci.isva = deserialize(s)::Bool + end if format_version(s) >= 21 ci.inlining = deserialize(s)::UInt8 end diff --git a/stdlib/Test/src/precompile.jl b/stdlib/Test/src/precompile.jl index 1e53033a091432..04907f84254400 100644 --- a/stdlib/Test/src/precompile.jl +++ b/stdlib/Test/src/precompile.jl @@ -1,5 +1,6 @@ if Base.generating_output() - redirect_stdout(devnull) do +let + function example_payload() @testset "example" begin @test 1 == 1 @test_throws ErrorException error() @@ -8,4 +9,7 @@ if Base.generating_output() @test 1 ≈ 1.0000000000000001 end end + + redirect_stdout(example_payload, devnull) +end end diff --git a/test/compiler/contextual.jl b/test/compiler/contextual.jl index 17919c6a9f6add..9fb8f66b709f55 100644 --- a/test/compiler/contextual.jl +++ b/test/compiler/contextual.jl @@ -222,3 +222,25 @@ end end @test_throws "oh no" doit49715(sin, Tuple{Int}) + +# Test that the CodeInfo returned from generated function need not match the +# generator. +function overdubbee(a, b) + a + b +end +const overdubee_codeinfo = code_lowered(overdubbee, Tuple{Any, Any})[1] + +function overdub_generator(world::UInt, source::LineNumberNode, args...) + if length(args) != 2 + :(error("Wrong number of arguments")) + else + return copy(overdubee_codeinfo) + end +end + +@eval function overdub(args...) 
+ $(Expr(:meta, :generated, overdub_generator)) + $(Expr(:meta, :generated_only)) +end + +@test overdub(1, 2) == 3