Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: parameterize some of hard-coded inference logic #39439

Merged
merged 1 commit into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 40 additions & 21 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
napplicable = length(applicable)
rettype = Bottom
edgecycle = false
edges = Any[]
edges = MethodInstance[]
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
Expand All @@ -115,7 +114,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
match = applicable[i]::MethodMatch
method = match.method
sig = match.spec_types
if istoplevel && !isdispatchtuple(sig)
if bail_out_toplevel_call(interp, sig, sv)
# only infer concrete call sites in top-level expressions
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
rettype = Any
Expand All @@ -135,7 +134,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
edgecycle |= edgecycle1::Bool
this_rt = tmerge(this_rt, rt)
this_rt === Any && break
if bail_out_call(interp, this_rt, sv)
break
end
end
else
this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
Expand All @@ -153,7 +154,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
seen += 1
rettype = tmerge(rettype, this_rt)
rettype === Any && break
if bail_out_call(interp, rettype, sv)
break
end
end
# try constant propagation if only 1 method is inferred to non-Bottom
# this is in preparation for inlining, or improving the return result
Expand All @@ -179,18 +182,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
if !(rettype === Any) # adding a new method couldn't refine (widen) this type
for edge in edges
add_backedge!(edge::MethodInstance, sv)
end
for (thisfullmatch, mt) in zip(fullmatch, mts)
if !thisfullmatch
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge!(mt, atype, sv)
end
end
end
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
#print("=> ", rettype, "\n")
if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
Expand All @@ -205,6 +197,27 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(rettype, info)
end

# Register the edges that tie the caller frame `sv` to the callees of one call site.
#
# - `rettype`: the merged return type inferred for the call site
# - `edges`: the method instances to which a backedge is added
# - `fullmatch`/`mts`: iterated pairwise; a method table gets a backedge (for the
#   call signature `atype`) only when its `fullmatch` flag is false
#
# Nothing is registered when `rettype` is `Any`.
function add_call_backedges!(interp::AbstractInterpreter,
                             @nospecialize(rettype),
                             edges::Vector{MethodInstance},
                             fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable},
                             @nospecialize(atype),
                             sv::InferenceState)
    # for `NativeInterpreter`, no backedges are needed here: adding a new method
    # couldn't refine (widen) a return type that is already `Any`
    rettype === Any && return

    for callee in edges
        add_backedge!(callee, sv)
    end

    for (covered, table) in zip(fullmatch, mts)
        covered && continue
        # also need an edge to the method table in case something gets
        # added that did not intersect with any existing method
        add_mt_backedge!(table, atype, sv)
    end
end

function const_prop_profitable(@nospecialize(arg))
# have new information from argtypes that wasn't available from the signature
Expand Down Expand Up @@ -746,7 +759,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
call = abstract_call(interp, nothing, ct, sv, max_methods)
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if res === Any
if bail_out_apply(interp, res, sv)
# No point carrying forward the info, we're not gonna inline it anyway
retinfo = nothing
break
Expand Down Expand Up @@ -1171,7 +1184,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
if ai === Bottom
if bail_out_statement(interp, ai, sv)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Upon further reflection, I'm not sure about this one. The return Bottom here kinda implies that the only thing this could possibly do is to check for a literal Bottom, no? Maybe it's a similar situation where the bail check needs to be separate from the check for Bottom.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this could be useful for code analysis (but not in terms of optimization).

Say if we analyze this kind of code below:

# Example for error reporting: `foo` contains two independent errors, one in its own
# body and one inside the callee `sin_broad`.
function foo(a)
    a = a * missing # => NoMethodError
    sin_broad(a)
end

# `sinn` is not defined anywhere in this example (presumably a typo for `sin`).
sin_broad(a) = sinn.(a) # => UndefVarError

foo([1,2,3]) # entry point

I believe it's better to report two errors in this case (i.e. NoMethodError(*, (Vector{Int}, Missing)), UndefVarError(:sinn))).
With the bail_out_statement and bail_out_local interface, we can do something like below:

import Core.Compiler:
    bail_out_local,
    bail_out_statement,
    abstract_call

# For `JETInterpreter`, never bail out, whatever type `t` was inferred.
bail_out_local(interp::JETInterpreter, @nospecialize(t), sv) = false
bail_out_statement(interp::JETInterpreter, @nospecialize(t), sv) = false

# Overload `abstract_call` for `JETInterpreter`: every `Bottom` entry of `argtypes` is
# replaced with `Any`, then the call is forwarded to the `AbstractInterpreter` method
# via `@invoke`.
function CC.abstract_call(interp::JETInterpreter, ea::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
                          sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
    argtypes = Any[a === Bottom ? Any : a for a in argtypes] # assuming the previous error is fixed, keep analysis going
    @invoke abstract_call(interp::AbstractInterpreter, ea::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
                          sv::InferenceState, max_methods::Int)
end

and keep the analysis going, collecting as many errors as possible even when we know the actual execution terminates earlier.

I agree that this interface would be weird from an optimization perspective, though.

/cc @vtjnash because you added a commit to revert these interfaces in #39606. Could I hear your thoughts on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I thought you wanted to use these to early-out in useless analysis rather than to continue into unreachable code-paths. I think doing the latter can be pretty tricky.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those aren't the same conditions though. If you did want to bail early, it should be using the Union{} lattice element (otherwise we may exit before the lattice has finished converging). If you want to bail late, that decision doesn't happen here either (if later tfuncs are correct, they should also return Union{} in this case, but it would be a wild mess to add this check everywhere). That decision happens in typeinf_local, where we decide to stop forwarding information to the next statement rather than, say, attempting to set changes = fill!(similar(changes), VarState(Union{}, true)) so that we reach each subsequent statement (though most will also error; since this is dead/unreachable code, that is what we want, to avoid poisoning the actual inference results at the next phi node join point).

I agree you could replace Union{} with Any (not here, but when they are accessed from the Slot), but that might quite quickly confuse the convergence algorithm and drive many more functions towards being uninferrable. To avoid that, I think you'd realistically need to run inference repeatedly, "fixing" one error at a time from the outside driver.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I strongly dislike parametrization just for the sake of parameterization, since it bloats the compiler and makes it harder to follow the logical flow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your responses.
Fair. As you pointed out, my idea would have to be really complicated; otherwise it would ruin inference results and convergence.
Even if I try that in the future, I think I will end up overloading the entire body of typeinf_local anyway, so honestly I'm not too attached to the bail_out_local/bail_out_statement interface, and I'm okay with #39606.

In general, I strongly dislike parametrization just for the sake of parameterization, since it bloats the compiler and makes it harder to follow the logical flow.

Yeah, that's what I was worried about. Are you okay with bail_out_call/bail_out_apply? I'm fairly sure they won't destroy convergence or inference results anyway, but they might still look weird in terms of the implementation of NativeInterpreter. Still, they are very useful for the implementation of JETInterpreter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those seem less bad, since it is probably just a performance optimization. There's also possibly not really any other correct behavior there to put in the overload either? We already computed the final call graph (info) and return type (Any), and no further analysis work should alter those. This is also where we see how the optimizer is designed to run as a separate pass: it consumes both of those bits of information later and avoids depending on the precise way the answer was arrived at.

return Bottom
end
argtypes[i] = ai
Expand Down Expand Up @@ -1349,6 +1362,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
if condt === Bottom
empty!(frame.pclimitations)
end
if bail_out_local(interp, condt, frame)
break
end
condval = maybe_extract_const_bool(condt)
Expand Down Expand Up @@ -1440,7 +1455,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
else
if hd === :(=)
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
t === Bottom && break
if bail_out_local(interp, t, frame)
break
end
frame.src.ssavaluetypes[pc] = t
lhs = stmt.args[1]
if isa(lhs, Slot)
Expand All @@ -1455,7 +1472,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
t === Bottom && break
if bail_out_local(interp, t, frame)
break
end
if !isempty(frame.ssavalue_uses[pc])
record_ssa_assign(pc, t, frame)
else
Expand Down
13 changes: 13 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,16 @@ may_compress(ni::NativeInterpreter) = true
may_discard_trees(ni::NativeInterpreter) = true

method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai))

# Inference bail-out predicates.
# By default (i.e. for `NativeInterpreter`) inference bails out when
# - a lattice element has grown up to `Any`:
#   `bail_out_call` (inter-procedural call) and `bail_out_apply` (abstract apply)
# - a lattice element has gone down to `Bottom`:
#   `bail_out_statement` (statement inference) and `bail_out_local` (local frame inference)
# - a non-concrete toplevel call site would be inferred: `bail_out_toplevel_call`
function bail_out_call(interp::AbstractInterpreter, @nospecialize(t), sv)
    return t === Any
end
function bail_out_apply(interp::AbstractInterpreter, @nospecialize(t), sv)
    return t === Any
end
function bail_out_statement(interp::AbstractInterpreter, @nospecialize(t), sv)
    return t === Bottom
end
function bail_out_local(interp::AbstractInterpreter, @nospecialize(t), sv)
    return t === Bottom
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between bail_out_statement and bail_out_local?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bail_out_local controls the frame-level bail-out logic in typeinf_local, while bail_out_statement controls the statement-level bail-out logic in abstract_eval_statement.
The reason I separate them is that it gives us more fine-grained control over the heuristic, which we may want to use.

# Decide whether inference of the call site with signature `sig` should be refused:
# true only when the frame `sv` belongs to a top-level expression (its `linfo.def` is a
# `Module`) and `sig` is not a dispatch tuple.
function bail_out_toplevel_call(interp::AbstractInterpreter, @nospecialize(sig), sv)
    sv.linfo.def isa Module || return false
    return !isdispatchtuple(sig)
end