Skip to content

Commit

Permalink
Revert "generators: expose caller world to GeneratedFunctionStub (#48611
Browse files Browse the repository at this point in the history
)" (#48763)

This reverts commit e3d366f.
  • Loading branch information
maleadt authored Feb 23, 2023
1 parent b600f51 commit 7823552
Show file tree
Hide file tree
Showing 31 changed files with 203 additions and 203 deletions.
2 changes: 1 addition & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules
for match = _methods(+, (Int, Int), -1, get_world_counter())
m = match.method
delete!(push!(Set{Method}(), m), m)
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match), typemax(UInt)))
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match)))

empty!(Set())
push!(push!(Set{Union{GlobalRef,Symbol}}(), :two), GlobalRef(Base, :two))
Expand Down
27 changes: 15 additions & 12 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,25 +590,28 @@ println(@nospecialize a...) = println(stdout, a...)

struct GeneratedFunctionStub
gen
argnames::SimpleVector
spnames::SimpleVector
argnames::Array{Any,1}
spnames::Union{Nothing, Array{Any,1}}
line::Int
file::Symbol
expand_early::Bool
end

# invoke and wrap the results of @generated expression
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
# args is (spvals..., argtypes...)
# invoke and wrap the results of @generated
function (g::GeneratedFunctionStub)(@nospecialize args...)
body = g.gen(args...)
file = source.file
file isa Symbol || (file = :none)
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
Expr(:var"scope-block",
if body isa CodeInfo
return body
end
lam = Expr(:lambda, g.argnames,
Expr(Symbol("scope-block"),
Expr(:block,
source,
Expr(:meta, :push_loc, file, :var"@generated body"),
LineNumberNode(g.line, g.file),
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
Expr(:return, body),
Expr(:meta, :pop_loc))))
spnames = g.spnames
if spnames === svec()
if spnames === nothing
return lam
else
return Expr(Symbol("with-static-parameters"), lam, spnames...)
Expand Down
13 changes: 6 additions & 7 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
break
end
topmost === nothing || continue
if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv)
if edge_matches_sv(infstate, method, sig, sparams, hardlimit, sv)
topmost = infstate
edgecycle = true
end
Expand Down Expand Up @@ -677,13 +677,12 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
end

function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
# The `method_for_inference_heuristics` will expand the given method's generator if
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
# access it directly instead (to avoid regeneration).
world = get_world_counter(interp)
callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing}
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}

inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing)
Expand Down Expand Up @@ -720,11 +719,11 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, met
end

# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
if isdefined(method, :generator) && !(method.generator isa Core.GeneratedFunctionStub) && may_invoke_generator(method, sig, sparams)
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
method_instance = specialize_method(method, sig, sparams)
if isa(method_instance, MethodInstance)
cinfo = get_staged(method_instance, world)
cinfo = get_staged(method_instance)
if isa(cinfo, CodeInfo)
method2 = cinfo.method_for_inference_limit_heuristics
if method2 isa Method
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let interp = NativeInterpreter()
else
tt = Tuple{typeof(f), Vararg{Any}}
end
for m in _methods_by_ftype(tt, 10, get_world_counter())::Vector
for m in _methods_by_ftype(tt, 10, typemax(UInt))::Vector
# remove any TypeVars from the intersection
m = m::MethodMatch
typ = Any[m.spec_types.parameters...]
Expand Down
3 changes: 1 addition & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,7 @@ end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
world = get_world_counter(interp)
src = retrieve_code_info(result.linfo, world)
src = retrieve_code_info(result.linfo)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
return InferenceState(result, src, cache, interp)
Expand Down
3 changes: 1 addition & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
end
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
world = get_world_counter(interp)
src = retrieve_code_info(linfo, world)
src = retrieve_code_info(linfo)
src === nothing && return nothing
return OptimizationState(linfo, src, params, interp)
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
end
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_sysimg()
return retrieve_code_info(mi, get_world_counter(interp))
return retrieve_code_info(mi)
end
lock_mi_inference(interp, mi)
result = InferenceResult(mi, typeinf_lattice(interp))
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ struct NativeInterpreter <: AbstractInterpreter
cache = Vector{InferenceResult}() # Initially empty cache

# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age for correctness
# we cap it to the current world age
if world == typemax(UInt)
world = get_world_counter()
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,23 @@ end
invoke_api(li::CodeInstance) = ccall(:jl_invoke_api, Cint, (Any,), li)
use_const_api(li::CodeInstance) = invoke_api(li) == 2

function get_staged(mi::MethodInstance, world::UInt)
function get_staged(mi::MethodInstance)
may_invoke_generator(mi) || return nothing
try
# user code might throw errors – ignore them
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
return ci
catch
return nothing
end
end

function retrieve_code_info(linfo::MethodInstance, world::UInt)
function retrieve_code_info(linfo::MethodInstance)
m = linfo.def::Method
c = nothing
if isdefined(m, :generator)
# user code might throw errors – ignore them
c = get_staged(linfo, world)
c = get_staged(linfo)
end
if c === nothing && isdefined(m, :source)
src = m.source
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,15 @@ end

"""
validate_code!(errors::Vector{InvalidCodeError}, mi::MethodInstance,
c::Union{Nothing,CodeInfo})
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
Validate `mi`, logging any violation by pushing an `InvalidCodeError` into `errors`.
If `isa(c, CodeInfo)`, also call `validate_code!(errors, c)`. It is assumed that `c` is
a `CodeInfo` instance associated with `mi`.
the `CodeInfo` instance associated with `mi`.
"""
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance, c::Union{Nothing,CodeInfo})
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance,
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
is_top_level = mi.def isa Module
if is_top_level
mnargs = 0
Expand Down
5 changes: 4 additions & 1 deletion base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,10 @@ macro generated(f)
Expr(:block,
lno,
Expr(:if, Expr(:generated),
body,
# https://github.com/JuliaLang/julia/issues/25678
Expr(:block,
:(local $tmp = $body),
:(if $tmp isa $(GlobalRef(Core, :CodeInfo)); return $tmp; else $tmp; end)),
Expr(:block,
Expr(:meta, :generated_only),
Expr(:return, nothing))))))
Expand Down
38 changes: 14 additions & 24 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -961,11 +961,10 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
if debuginfo !== :source && debuginfo !== :none
throw(ArgumentError("'debuginfo' must be either :source or :none"))
end
world = get_world_counter()
return map(method_instances(f, t, world)) do m
return map(method_instances(f, t)) do m
if generated && hasgenerator(m)
if may_invoke_generator(m)
return ccall(:jl_code_for_staged, Any, (Any, UInt), m, world)::CodeInfo
return ccall(:jl_code_for_staged, Any, (Any,), m)::CodeInfo
else
error("Could not expand generator for `@generated` method ", m, ". ",
"This can happen if the provided argument types (", t, ") are ",
Expand Down Expand Up @@ -1054,8 +1053,6 @@ methods(@nospecialize(f), @nospecialize(t), mod::Module) = methods(f, t, (mod,))
function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
tt = signature_type(f, t)
world = get_world_counter()
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
min = RefValue{UInt}(typemin(UInt))
max = RefValue{UInt}(typemax(UInt))
ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector
Expand Down Expand Up @@ -1128,11 +1125,9 @@ _uncompressed_ir(ci::Core.CodeInstance, s::Array{UInt8,1}) = ccall(:jl_uncompres
const uncompressed_ast = uncompressed_ir
const _uncompressed_ast = _uncompressed_ir

function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt=get_world_counter())
tt = signature_type(f, t)
results = Core.MethodInstance[]
# this make a better error message than the typeassert that follows
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
for match in _methods_by_ftype(tt, -1, world)::Vector
instance = Core.Compiler.specialize_method(match)
push!(results, instance)
Expand Down Expand Up @@ -1203,22 +1198,20 @@ function may_invoke_generator(method::Method, @nospecialize(atype), sparams::Sim
# generator only has one method
generator = method.generator
isa(generator, Core.GeneratedFunctionStub) || return false
gen_mthds = _methods_by_ftype(Tuple{typeof(generator.gen), Vararg{Any}}, 1, method.primary_world)
(gen_mthds isa Vector && length(gen_mthds) == 1) || return false
gen_mthds = methods(generator.gen)::MethodList
length(gen_mthds) == 1 || return false

generator_method = first(gen_mthds).method
generator_method = first(gen_mthds)
nsparams = length(sparams)
isdefined(generator_method, :source) || return false
code = generator_method.source
nslots = ccall(:jl_ir_nslots, Int, (Any,), code)
at = unwrap_unionall(atype)
at isa DataType || return false
at = unwrap_unionall(atype)::DataType
(nslots >= 1 + length(sparams) + length(at.parameters)) || return false

firstarg = 1
for i = 1:nsparams
if isa(sparams[i], TypeVar)
if (ast_slotflag(code, firstarg + i) & SLOT_USED) != 0
if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0
return false
end
end
Expand All @@ -1227,15 +1220,15 @@ function may_invoke_generator(method::Method, @nospecialize(atype), sparams::Sim
non_va_args = method.isva ? nargs - 1 : nargs
for i = 1:non_va_args
if !isdispatchelem(at.parameters[i])
if (ast_slotflag(code, firstarg + i + nsparams) & SLOT_USED) != 0
if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0
return false
end
end
end
if method.isva
# If the va argument is used, we need to ensure that all arguments that
# contribute to the va tuple are dispatchelemes
if (ast_slotflag(code, firstarg + nargs + nsparams) & SLOT_USED) != 0
if (ast_slotflag(code, 1 + nargs + nsparams) & SLOT_USED) != 0
for i = (non_va_args+1):length(at.parameters)
if !isdispatchelem(at.parameters[i])
return false
Expand Down Expand Up @@ -1325,8 +1318,7 @@ function code_typed_by_type(@nospecialize(tt::Type);
debuginfo::Symbol=:default,
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
if @isdefined(IRShow)
debuginfo = IRShow.debuginfo(debuginfo)
elseif debuginfo === :default
Expand Down Expand Up @@ -1435,7 +1427,7 @@ function code_ircode_by_type(
interp = Core.Compiler.NativeInterpreter(world),
optimize_until::Union{Integer,AbstractString,Nothing} = nothing,
)
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
ccall(:jl_is_in_pure_context, Bool, ()) &&
error("code reflection cannot be used from generated functions")
tt = to_tuple_type(tt)
matches = _methods_by_ftype(tt, -1, world)::Vector
Expand All @@ -1462,8 +1454,7 @@ end
function return_types(@nospecialize(f), @nospecialize(types=default_tt(f));
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
if isa(f, Core.OpaqueClosure)
_, rt = only(code_typed_opaque_closure(f))
return Any[rt]
Expand All @@ -1487,8 +1478,7 @@ end
function infer_effects(@nospecialize(f), @nospecialize(types=default_tt(f));
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
if isa(f, Core.Builtin)
types = to_tuple_type(types)
argtypes = Any[Core.Compiler.Const(f), types.parameters...]
Expand Down
8 changes: 4 additions & 4 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,10 @@ A (usually temporary) container for holding lowered source code.

A `UInt8` array of slot properties, represented as bit flags:

* 0x02 - assigned (only false if there are *no* assignment statements with this var on the left)
* 0x08 - used (if there is any read or write of the slot)
* 0x10 - statically assigned once
* 0x20 - might be used before assigned. This flag is only valid after type inference.
* 2 - assigned (only false if there are *no* assignment statements with this var on the left)
* 8 - const (currently unused for local variables)
* 16 - statically assigned once
* 32 - might be used before assigned. This flag is only valid after type inference.

* `ssavaluetypes`

Expand Down
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
if (src)
jlrettype = src->rettype;
else if (jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi, world) : (jl_code_info_t*)mi->def.method->source;
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
}
Expand Down
4 changes: 2 additions & 2 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,10 @@ static jl_value_t *jl_invoke_julia_macro(jl_array_t *args, jl_module_t *inmodule
jl_value_t *result;
JL_TRY {
margs[0] = jl_toplevel_eval(*ctx, margs[0]);
jl_method_instance_t *mfunc = jl_method_lookup(margs, nargs, ct->world_age);
jl_method_instance_t *mfunc = jl_method_lookup(margs, nargs, world);
JL_GC_PROMISE_ROOTED(mfunc);
if (mfunc == NULL) {
jl_method_error(margs[0], &margs[1], nargs, ct->world_age);
jl_method_error(margs[0], &margs[1], nargs, world);
// unreachable
}
*ctx = mfunc->def.method->module;
Expand Down
7 changes: 1 addition & 6 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ extern "C" {
JL_DLLEXPORT _Atomic(size_t) jl_world_counter = 1; // uses atomic acquire/release
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT
{
jl_task_t *ct = jl_current_task;
if (ct->ptls->in_pure_callback)
return ~(size_t)0;
return jl_atomic_load_acquire(&jl_world_counter);
}

Expand Down Expand Up @@ -2299,7 +2296,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
// if that didn't work and compilation is off, try running in the interpreter
if (compile_option == JL_OPTIONS_COMPILE_OFF ||
compile_option == JL_OPTIONS_COMPILE_MIN) {
jl_code_info_t *src = jl_code_for_interpreter(mi, world);
jl_code_info_t *src = jl_code_for_interpreter(mi);
if (!jl_code_requires_compiler(src, 0)) {
jl_code_instance_t *codeinst = jl_new_codeinst(mi,
(jl_value_t*)jl_any_type, NULL, NULL,
Expand Down Expand Up @@ -3142,8 +3139,6 @@ static jl_value_t *ml_matches(jl_methtable_t *mt,
int intersections, size_t world, int cache_result,
size_t *min_valid, size_t *max_valid, int *ambig)
{
if (world > jl_atomic_load_acquire(&jl_world_counter))
return jl_nothing; // the future is not enumerable
int has_ambiguity = 0;
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)type);
assert(jl_is_datatype(unw));
Expand Down
Loading

0 comments on commit 7823552

Please sign in to comment.