Skip to content

Commit

Permalink
Merge pull request #45299 from JuliaLang/kf/rt_effect_free
Browse files Browse the repository at this point in the history
Fix effects modeling for return_type
  • Loading branch information
Keno authored May 16, 2022
2 parents f9aa28f + aad9ff6 commit 37dd084
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 74 deletions.
114 changes: 58 additions & 56 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# At this point we are guaranteed to end up throwing on this path,
# which is all that's required for :consistent-cy. Of course, we don't
# know anything else about this statement.
tristate_merge!(sv, Effects(; consistent=ALWAYS_TRUE, nonoverlayed))
return CallMeta(Any, false)
effects = Effects(; consistent=ALWAYS_TRUE, nonoverlayed)
return CallMeta(Any, effects, false)
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp),
InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
tristate_merge!(sv, Effects())
return CallMeta(Any, false)
return CallMeta(Any, Effects(), false)
end

(; valid_worlds, applicable, info) = matches
Expand All @@ -97,7 +96,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),

# try pure-evaluation
val = pure_eval_call(interp, f, applicable, arginfo, sv)
val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s)
val !== nothing && return CallMeta(val, all_effects, MethodResultPure(info)) # TODO: add some sort of edge(s)

for i in 1:napplicable
match = applicable[i]::MethodMatch
Expand Down Expand Up @@ -240,8 +239,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
delete!(sv.pclimitations, caller)
end
end
tristate_merge!(sv, all_effects)
return CallMeta(rettype, info)
return CallMeta(rettype, all_effects, info)
end

struct FailedMethodMatch
Expand Down Expand Up @@ -1193,7 +1191,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
# TODO: this doesn't realize that Array, SimpleVector, Tuple, and NamedTuple do not use the iterate protocol
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, info)])
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)])
valtype = statetype = Bottom
ret = Any[]
calls = CallMeta[call]
Expand Down Expand Up @@ -1269,23 +1267,23 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
max_methods::Int = get_max_methods(sv.mod, interp))
itft = argtype_by_index(argtypes, 2)
aft = argtype_by_index(argtypes, 3)
(itft === Bottom || aft === Bottom) && return CallMeta(Bottom, false)
(itft === Bottom || aft === Bottom) && return CallMeta(Bottom, EFFECTS_THROWS, false)
aargtypes = argtype_tail(argtypes, 4)
aftw = widenconst(aft)
if !isa(aft, Const) && !isa(aft, PartialOpaque) && (!isType(aftw) || has_free_typevars(aftw))
if !isconcretetype(aftw) || (aftw <: Builtin)
add_remark!(interp, sv, "Core._apply_iterate called on a function of a non-concrete type")
tristate_merge!(sv, Effects())
# bail now, since it seems unlikely that abstract_call will be able to do any better after splitting
# this also ensures we don't call abstract_call_gf_by_type below on an IntrinsicFunction or Builtin
return CallMeta(Any, false)
return CallMeta(Any, Effects(), false)
end
end
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
ctypes = [Any[aft]]
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
effects = EFFECTS_TOTAL
for i = 1:nargs
ctypes´ = Vector{Any}[]
infos′ = Vector{MaybeAbstractIterationInfo}[]
Expand Down Expand Up @@ -1348,6 +1346,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods)
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
effects = tristate_merge(effects, call.effects)
if bail_out_apply(interp, res, sv)
if i != length(ctypes)
# No point carrying forward the info, we're not gonna inline it anyway
Expand All @@ -1358,7 +1357,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
end
# TODO: Add a special info type to capture all the iteration info.
# For now, only propagate info if we don't also union-split the iteration
return CallMeta(res, retinfo)
return CallMeta(res, effects, retinfo)
end

function argtype_by_index(argtypes::Vector{Any}, i::Int)
Expand Down Expand Up @@ -1539,21 +1538,21 @@ end
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState)
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, false), EFFECTS_THROWS
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, false)
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
types === Bottom && return CallMeta(Bottom, false), EFFECTS_THROWS
isexact || return CallMeta(Any, false), Effects()
types === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, false)
isexact || return CallMeta(Any, Effects(), false)
argtype = argtypes_to_type(argtype_tail(argtypes, 4))
nargtype = typeintersect(types, argtype)
nargtype === Bottom && return CallMeta(Bottom, false), EFFECTS_THROWS
nargtype isa DataType || return CallMeta(Any, false), Effects() # other cases are not implemented below
isdispatchelem(ft) || return CallMeta(Any, false), Effects() # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
nargtype === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, false)
nargtype isa DataType || return CallMeta(Any, Effects(), false) # other cases are not implemented below
isdispatchelem(ft) || return CallMeta(Any, Effects(), false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
ft = ft::DataType
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match === nothing && return CallMeta(Any, false), Effects()
match === nothing && return CallMeta(Any, Effects(), false)
update_valid_age!(sv, valid_worlds)
method = match.method
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
Expand All @@ -1580,7 +1579,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
end
end
effects = Effects(effects; nonoverlayed=!overlayed)
return CallMeta(from_interprocedural!(rt, sv, arginfo, sig), InvokeCallInfo(match, const_result)), effects
return CallMeta(from_interprocedural!(rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
end

function invoke_rewrite(xs::Vector{Any})
Expand All @@ -1601,37 +1600,30 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
if f === _apply_iterate
return abstract_apply(interp, argtypes, sv, max_methods)
elseif f === invoke
call, effects = abstract_invoke(interp, arginfo, sv)
tristate_merge!(sv, effects)
return call
return abstract_invoke(interp, arginfo, sv)
elseif f === modifyfield!
tristate_merge!(sv, Effects()) # TODO
return abstract_modifyfield!(interp, argtypes, sv)
end
rt = abstract_call_builtin(interp, f, arginfo, sv, max_methods)
tristate_merge!(sv, builtin_effects(f, argtypes, rt))
return CallMeta(rt, false)
return CallMeta(rt, builtin_effects(f, argtypes, rt), false)
elseif isa(f, Core.OpaqueClosure)
# calling an OpaqueClosure about which we have no information returns no information
tristate_merge!(sv, Effects())
return CallMeta(Any, false)
return CallMeta(Any, Effects(), false)
elseif f === Core.kwfunc
if la == 2
aty = argtypes[2]
if !isvarargtype(aty)
ft = widenconst(aty)
if isa(ft, DataType) && isdefined(ft.name, :mt) && isdefined(ft.name.mt, :kwsorter)
return CallMeta(Const(ft.name.mt.kwsorter), MethodResultPure())
return CallMeta(Const(ft.name.mt.kwsorter), EFFECTS_TOTAL, MethodResultPure())
end
end
end
tristate_merge!(sv, EFFECTS_UNKNOWN) # TODO
return CallMeta(Any, false)
return CallMeta(Any, EFFECTS_UNKNOWN, false)
elseif f === TypeVar
# Manually look through the definition of TypeVar to
# make sure to be able to get `PartialTypeVar`s out.
tristate_merge!(sv, EFFECTS_UNKNOWN) # TODO
(la < 2 || la > 4) && return CallMeta(Union{}, false)
(la < 2 || la > 4) && return CallMeta(Union{}, EFFECTS_UNKNOWN, false)
n = argtypes[2]
ub_var = Const(Any)
lb_var = Const(Union{})
Expand All @@ -1641,36 +1633,33 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
elseif la == 3
ub_var = argtypes[3]
end
return CallMeta(typevar_tfunc(n, lb_var, ub_var), false)
return CallMeta(typevar_tfunc(n, lb_var, ub_var), EFFECTS_UNKNOWN, false)
elseif f === UnionAll
tristate_merge!(sv, EFFECTS_UNKNOWN) # TODO
return CallMeta(abstract_call_unionall(argtypes), false)
return CallMeta(abstract_call_unionall(argtypes), EFFECTS_UNKNOWN, false)
elseif f === Tuple && la == 2
tristate_merge!(sv, EFFECTS_UNKNOWN) # TODO
aty = argtypes[2]
ty = isvarargtype(aty) ? unwrapva(aty) : widenconst(aty)
if !isconcretetype(ty)
return CallMeta(Tuple, false)
return CallMeta(Tuple, EFFECTS_UNKNOWN, false)
end
elseif is_return_type(f)
tristate_merge!(sv, EFFECTS_UNKNOWN) # TODO
return return_type_tfunc(interp, argtypes, sv)
elseif la == 2 && istopfunction(f, :!)
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.effects, call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
rty = abstract_call_known(interp, (===), arginfo, sv, max_methods).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), EFFECTS_TOTAL, false) # swap if-else
elseif isa(rty, Const)
return CallMeta(Const(rty.val === false), MethodResultPure())
return CallMeta(Const(rty.val === false), EFFECTS_TOTAL, MethodResultPure())
end
return CallMeta(rty, false)
return CallMeta(rty, EFFECTS_TOTAL, false)
elseif la == 3 && istopfunction(f, :(>:))
# mark issupertype as a exact alias for issubtype
# swap T1 and T2 arguments and call <:
Expand All @@ -1680,26 +1669,26 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs = nothing
end
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv, max_methods).rt, false)
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv, max_methods).rt, EFFECTS_TOTAL, false)
elseif la == 2 &&
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
istopfunction(f, :length)
# mark length(::SimpleVector) as @pure
return CallMeta(Const(length(svecval)), MethodResultPure())
return CallMeta(Const(length(svecval)), EFFECTS_TOTAL, MethodResultPure())
elseif la == 3 &&
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
(a3 = argtypes[3]; isa(a3, Const)) && (idx = a3.val; isa(idx, Int)) &&
istopfunction(f, :getindex)
# mark getindex(::SimpleVector, i::Int) as @pure
if 1 <= idx <= length(svecval) && isassigned(svecval, idx)
return CallMeta(Const(getindex(svecval, idx)), MethodResultPure())
return CallMeta(Const(getindex(svecval, idx)), EFFECTS_TOTAL, MethodResultPure())
end
elseif la == 2 && istopfunction(f, :typename)
return CallMeta(typename_static(argtypes[2]), MethodResultPure())
return CallMeta(typename_static(argtypes[2]), EFFECTS_TOTAL, MethodResultPure())
elseif la == 3 && istopfunction(f, :typejoin)
if is_all_const_arg(arginfo)
val = _pure_eval_call(f, arginfo)
return CallMeta(val === nothing ? Type : val, MethodResultPure())
return CallMeta(val === nothing ? Type : val, EFFECTS_TOTAL, MethodResultPure())
end
end
atype = argtypes_to_type(argtypes)
Expand All @@ -1708,7 +1697,7 @@ end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, arginfo::ArgInfo, sv::InferenceState)
sig = argtypes_to_type(arginfo.argtypes)
(; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, edge_effects) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
Expand All @@ -1724,7 +1713,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
end
end
info = OpaqueClosureCallInfo(match, const_result)
return CallMeta(from_interprocedural!(rt, sv, arginfo, match.spec_types), info)
return CallMeta(from_interprocedural!(rt, sv, arginfo, match.spec_types), edge_effects, info)
end

function most_general_argtypes(closure::PartialOpaque)
Expand All @@ -1746,18 +1735,30 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
if isa(ft, PartialOpaque)
newargtypes = copy(argtypes)
newargtypes[1] = ft.env
tristate_merge!(sv, Effects()) # TODO
return abstract_call_opaque_closure(interp, ft, ArgInfo(arginfo.fargs, newargtypes), sv)
body_call = abstract_call_opaque_closure(interp, ft, ArgInfo(arginfo.fargs, newargtypes), sv)
# Analyze implicit type asserts on argument and return type
ftt = ft.typ
(at, rt) = unwrap_unionall(ftt).parameters
if isa(rt, TypeVar)
rt = rewrap_unionall(rt.lb, ftt)
else
rt = rewrap_unionall(rt, ftt)
end
nothrow = body_call.rt rt
if nothrow
nothrow = tuple_tfunc(newargtypes[2:end]) rewrap_unionall(at, ftt)
end
return CallMeta(body_call.rt, Effects(body_call.effects,
nothrow = nothrow ? TRISTATE_UNKNOWN : body_call.effects.nothrow),
body_call.info)
elseif (uft = unwrap_unionall(widenconst(ft)); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
tristate_merge!(sv, Effects()) # TODO
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], widenconst(ft)), false)
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], widenconst(ft)), Effects(), false)
elseif f === nothing
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if hasintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure})
tristate_merge!(sv, Effects())
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
return CallMeta(Any, Effects(), false)
end
max_methods = max_methods === nothing ? get_max_methods(sv.mod, interp) : max_methods
return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods)
Expand Down Expand Up @@ -1885,6 +1886,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
t = Bottom
else
callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv)
tristate_merge!(sv, callinfo.effects)
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
end
Expand Down
1 change: 0 additions & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRC
f = argextype(args[1], src)
f = singleton_type(f)
f === nothing && return false
is_return_type(f) && return true
if isa(f, IntrinsicFunction)
intrinsic_effect_free_if_nothrow(f) || return false
return intrinsic_nothrow(f,
Expand Down
3 changes: 0 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1217,9 +1217,6 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
ir[SSAValue(idx)][:inst] = lateres.val
check_effect_free!(ir, idx, lateres.val, rt)
return nothing
elseif is_return_type(sig.f)
check_effect_free!(ir, idx, stmt, rt)
return nothing
end

return stmt, sig
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and any additional information (`call.info`) for a given generic call.
"""
struct CallMeta
rt::Any
effects::Effects
info::Any
end

Expand Down Expand Up @@ -81,7 +82,7 @@ effect-free, including being no-throw (typically because the value was computed
by calling an `@pure` function).
"""
struct MethodResultPure
info::Union{MethodMatchInfo,UnionSplitInfo,Bool}
info::Any
end
let instance = MethodResultPure(false)
global MethodResultPure
Expand Down
Loading

0 comments on commit 37dd084

Please sign in to comment.