Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: enable CodeInfo method_for_inference_limit_heuristics support #26822

Merged
merged 1 commit into from
Apr 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
cyclei = 0
infstate = sv
edgecycle = false
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing}
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
Expand All @@ -197,7 +198,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
edgecycle = true
break
end
if topmost === nothing
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
if topmost === nothing && method2 === inf_method2
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
method = linfo.def::Method
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
tree.signature_for_inference_heuristics = nothing
tree.method_for_inference_limit_heuristics = nothing
tree.slotnames = Any[ COMPILER_TEMP_SYM for i = 1:method.nargs ]
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
tree.slottypes = nothing
Expand Down
32 changes: 10 additions & 22 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,33 +155,21 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any, UInt), method, atypes, sparams, world)
end

# TODO: Use these functions instead of directly manipulating
# the "actual" method for appropriate places in inference (see #24676)
function method_for_inference_heuristics(cinfo, default)
if isa(cinfo, CodeInfo)
# appropriate format for `sig` is svec(ftype, argtypes, world)
sig = cinfo.signature_for_inference_heuristics
if isa(sig, SimpleVector) && length(sig) == 3
methods = _methods(sig[1], sig[2], -1, sig[3])
if length(methods) == 1
_, _, m = methods[]
if isa(m, Method)
return m
end
end
end
end
return default
end

function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
if isdefined(method, :generator) && method.generator.expand_early
method_instance = code_for_method(method, sig, sparams, world, false)
if isa(method_instance, MethodInstance)
return method_for_inference_heuristics(get_staged(method_instance), method)
cinfo = get_staged(method_instance)
if isa(cinfo, CodeInfo)
method2 = cinfo.method_for_inference_limit_heuristics
if method2 isa Method
return method2
end
end
end
end
return method
return nothing
end

function exprtype(@nospecialize(x), src, mod::Module)
Expand Down
3 changes: 2 additions & 1 deletion src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2300,7 +2300,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), 1);
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy);
}

ios_putc('\0', s.s);
Expand Down
2 changes: 1 addition & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2026,7 +2026,7 @@ void jl_init_types(void)
jl_perm_symsvec(12,
"code",
"codelocs",
"signature_for_inference_heuristics",
"method_for_inference_limit_heuristics",
"slottypes",
"ssavaluetypes",
"linetable",
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ typedef struct _jl_llvm_functions_t {
typedef struct _jl_code_info_t {
jl_array_t *code; // Any array of statements
jl_value_t *codelocs; // Int array of indicies into the line table
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference
jl_value_t *slottypes; // types of variable slots (or `nothing`)
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_value_t *linetable; // Table of locations
Expand Down
4 changes: 2 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
jl_array_del_end(meta, na - ins);
}
}
li->signature_for_inference_heuristics = jl_nothing;
li->method_for_inference_limit_heuristics = jl_nothing;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_len(vis);
Expand Down Expand Up @@ -303,7 +303,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
jl_code_info_type);
src->code = NULL;
src->signature_for_inference_heuristics = NULL;
src->method_for_inference_limit_heuristics = NULL;
src->slotnames = NULL;
src->slotflags = NULL;
src->slottypes = NULL;
Expand Down
2 changes: 1 addition & 1 deletion src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
jl_gc_wb(src, src->slotflags);
src->ssavaluetypes = jl_box_long(0);
jl_gc_wb(src, src->ssavaluetypes);
src->signature_for_inference_heuristics = jl_nothing;
src->method_for_inference_limit_heuristics = jl_nothing;
src->codelocs = jl_nothing;
src->linetable = jl_nothing;

Expand Down
112 changes: 79 additions & 33 deletions test/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1310,73 +1310,119 @@ function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, li
return Expr(:meta, :generated, stub)
end

f24852_kernel(x, y) = x * y

function f24852_kernel_cinfo(x, y)
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
f24852_kernel1(x, y::Tuple) = x * y[1][1][1]
f24852_kernel2(x, y::Tuple) = f24852_kernel1(x, (y,))
f24852_kernel3(x, y::Tuple) = f24852_kernel2(x, (y,))
f24852_kernel(x, y::Number) = f24852_kernel3(x, (y,))

function f24852_kernel_cinfo(fsig::Type)
world = typemax(UInt) # FIXME
sig, spvals, method = Base._methods_by_ftype(fsig, -1, world)[1]
isdefined(method, :source) || return (nothing, :(f(x, y)))
code_info = Base.uncompressed_ast(method)
body = Expr(:block, code_info.code...)
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 1, :propagate)
if startswith(String(method.name), "f24852")
for a in body.args
if a isa Expr && a.head == :(=)
a = a.args[2]
end
if a isa Expr && length(a.args) === 3 && a.head === :call
pushfirst!(a.args, Core.SlotNumber(1))
end
end
end
pushfirst!(code_info.slotnames, Symbol("#self#"))
pushfirst!(code_info.slotflags, 0x00)
return method, code_info
end

function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
_, code_info = f24852_kernel_cinfo(x, y)
function f24852_gen_cinfo_uninflated(X, Y, _, f, x, y)
_, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
return code_info
end

function f24852_gen_cinfo_inflated(X, Y, f, x, y)
method, code_info = f24852_kernel_cinfo(x, y)
code_info.signature_for_inference_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt))
function f24852_gen_cinfo_inflated(X, Y, _, f, x, y)
method, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
code_info.method_for_inference_limit_heuristics = method
return code_info
end

function f24852_gen_expr(X, Y, f, x, y)
return :(f24852_kernel(x::$X, y::$Y))
function f24852_gen_expr(X, Y, _, f, x, y) # deparse f(x::X, y::Y) where {X, Y}
if f === typeof(f24852_kernel)
f2 = :f24852_kernel3
elseif f === typeof(f24852_kernel3)
f2 = :f24852_kernel2
elseif f === typeof(f24852_kernel2)
f2 = :f24852_kernel1
elseif f === typeof(f24852_kernel1)
return :((x::$X) * (y::$Y)[1][1][1])
else
return :(error(repr(f)))
end
return :(f24852_late_expr($f2, x::$X, (y::$Y,)))
end

@eval begin
function f24852_late_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
function f24852_late_expr(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_late_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
function f24852_late_inflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
function f24852_late_uninflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
end

@eval begin
function f24852_early_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
function f24852_early_expr(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_early_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
function f24852_early_inflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
function f24852_early_uninflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
end

x, y = rand(), rand()
result = f24852_kernel(x, y)

@test result === f24852_late_expr(x, y)
@test result === f24852_late_uninflated(x, y)
@test result === f24852_late_inflated(x, y)

@test result === f24852_early_expr(x, y)
@test result === f24852_early_uninflated(x, y)
@test result === f24852_early_inflated(x, y)

# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
@test result === f24852_late_expr(f24852_kernel, x, y)
@test Base.return_types(f24852_late_expr, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_late_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_late_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]

@test result === f24852_early_expr(f24852_kernel, x, y)
@test Base.return_types(f24852_early_expr, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_early_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_early_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === @inferred f24852_early_inflated(f24852_kernel, x, y)
@test Base.return_types(f24852_early_inflated, typeof((f24852_kernel, x, y))) == Any[Float64]

# TODO: test that `expand_early = true` + inflated `method_for_inference_limit_heuristics`
# can be used to tighten up some inference result.

# Test that Conditional doesn't get widened to Bool too quickly
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ let code = Any[
))

Compiler.run_passes(ci, 1, Compiler.LineInfoNode[Compiler.NullLineInfo])
end
end