Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abstract_apply: Don't drop effects of iterate'd calls #47846

Merged
merged 2 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,14 @@ function ssa_def_slot(@nospecialize(arg), sv::InferenceState)
return arg
end

struct AbstractIterationResult
cti::Vector{Any}
info::MaybeAbstractIterationInfo
ai_effects::Effects
end
AbstractIterationResult(cti::Vector{Any}, info::MaybeAbstractIterationInfo) =
AbstractIterationResult(cti, info, EFFECTS_TOTAL)

# `typ` is the inferred type for expression `arg`.
# if the expression constructs a container (e.g. `svec(x,y,z)`),
# refine its type to an array of element types.
Expand All @@ -1352,14 +1360,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(typ, PartialStruct)
widet = typ.typ
if isa(widet, DataType) && widet.name === Tuple.name
return typ.fields, nothing
return AbstractIterationResult(typ.fields, nothing)
end
end

if isa(typ, Const)
val = typ.val
if isa(val, SimpleVector) || isa(val, Tuple)
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
return AbstractIterationResult(Any[ Const(val[i]) for i in 1:length(val) ], nothing) # avoid making a tuple Generator here!
end
end

Expand All @@ -1374,12 +1382,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(tti, Union)
utis = uniontypes(tti)
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
end
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing)
end
end
result = Any[ Union{} for _ in 1:ltp ]
Expand All @@ -1390,12 +1398,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return result, nothing
return AbstractIterationResult(result, nothing)
elseif tti0 <: Tuple
if isa(tti0, DataType)
return Any[ p for p in tti0.parameters ], nothing
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
elseif !isa(tti, DataType)
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing)
else
len = length(tti.parameters)
last = tti.parameters[len]
Expand All @@ -1404,12 +1412,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if va
elts[len] = Vararg{elts[len]}
end
return elts, nothing
return AbstractIterationResult(elts, nothing)
end
elseif tti0 === SimpleVector || tti0 === Any
return Any[Vararg{Any}], nothing
elseif tti0 === SimpleVector
return AbstractIterationResult(Any[Vararg{Any}], nothing)
elseif tti0 === Any
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
elseif tti0 <: Array
return Any[Vararg{eltype(tti0)}], nothing
return AbstractIterationResult(Any[Vararg{eltype(tti0)}], nothing)
else
return abstract_iteration(interp, itft, typ, sv)
end
Expand All @@ -1420,7 +1430,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
if isa(itft, Const)
iteratef = itft.val
else
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
end
@assert !isvarargtype(itertype)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), StmtInfo(true), sv)
Expand All @@ -1430,7 +1440,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
# TODO: this doesn't realize that Array, SimpleVector, Tuple, and NamedTuple do not use the iterate protocol
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)])
stateordonet === Bottom && return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)], true))
valtype = statetype = Bottom
ret = Any[]
calls = CallMeta[call]
Expand All @@ -1440,7 +1450,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# length iterators, or interesting prefix
while true
if stateordonet_widened === Nothing
return ret, AbstractIterationInfo(calls)
return AbstractIterationResult(ret, AbstractIterationInfo(calls, true))
end
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).max_tuple_splat
break
Expand All @@ -1452,7 +1462,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if ⊑(typeinf_lattice(interp), nstatetype, statetype)
return Any[Bottom], nothing
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_THROWS)
end
valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1))
push!(ret, valtype)
Expand Down Expand Up @@ -1482,7 +1492,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# ... but cannot terminate
if !may_have_terminated
# ... and cannot have terminated prior to this loop
return Any[Bottom], nothing
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_UNKNOWN′)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EFFECTS_THROWS seems to be more consistent with the change on L1465 above?

Suggested change
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_UNKNOWN′)
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_THROWS)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point, we have proven that we cannot terminate, but I'm not sure we've proven that the state has fixpointed, meaning that our call effects cover the largest possible effect set the call could produce during its execution.

else
# iterator may have terminated prior to this loop, but not during it
valtype = Bottom
Expand All @@ -1492,13 +1502,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv).rt
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv)
push!(calls, call)
stateordonet = call.rt
stateordonet_widened = widenconst(stateordonet)
end
if valtype !== Union{}
push!(ret, Vararg{valtype})
end
return ret, nothing
return AbstractIterationResult(ret, AbstractIterationInfo(calls, false))
end

# do apply(af, fargs...), where af is a function value
Expand Down Expand Up @@ -1529,13 +1541,9 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
infos′ = Vector{MaybeAbstractIterationInfo}[]
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
if !isvarargtype(ti)
cti_info = precise_container_type(interp, itft, ti, sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::MaybeAbstractIterationInfo
(;cti, info, ai_effects) = precise_container_type(interp, itft, ti, sv)
else
cti_info = precise_container_type(interp, itft, unwrapva(ti), sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::MaybeAbstractIterationInfo
(;cti, info, ai_effects) = precise_container_type(interp, itft, unwrapva(ti), sv)
# We can't represent a repeating sequence of the same types,
# so tmerge everything together to get one type that represents
# everything.
Expand All @@ -1548,6 +1556,12 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
cti = Any[Vararg{argt}]
end
effects = merge_effects(effects, ai_effects)
if info !== nothing
Comment on lines +1559 to +1560
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
effects = merge_effects(effects, ai_effects)
if info !== nothing
# merge effects of the `iterate` call
effects = merge_effects(effects, ai_effects)
# merge effects of call(s) with the iterated arguments
if info !== nothing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really the effects of the iterate call. It's more the effects of imprecision of the iterate call. The info.each, are all methods of iterate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I understand.

for call in info.each
effects = merge_effects(effects, call.effects)
end
end
if any(@nospecialize(t) -> t === Bottom, cti)
continue
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ function rewrite_apply_exprargs!(todo::Vector{Pair{Int,Any}},
def = argexprs[i]
def_type = argtypes[i]
thisarginfo = arginfos[i-arg_start]
if thisarginfo === nothing
if thisarginfo === nothing || !thisarginfo.complete
if def_type isa PartialStruct
# def_type.typ <: Tuple is assumed
def_argtypes = def_type.fields
Expand Down Expand Up @@ -1134,9 +1134,9 @@ function inline_apply!(todo::Vector{Pair{Int,Any}},
for i = (arg_start + 1):length(argtypes)
thisarginfo = nothing
if !is_valid_type_for_apply_rewrite(argtypes[i], state.params)
if isa(info, ApplyCallInfo) && info.arginfo[i-arg_start] !== nothing
thisarginfo = info.arginfo[i-arg_start]
else
isa(info, ApplyCallInfo) || return nothing
thisarginfo = info.arginfo[i-arg_start]
if thisarginfo === nothing || !thisarginfo.complete
return nothing
end
end
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int)
if !tpdum.complete
tpdum.ssa_uses[def] -= 1
else
@assert false && "TODO"
range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1))
# TODO: Sorted
useidx = findfirst(idx->tpdum.data[idx] == use, range)
@assert useidx !== nothing
idx = range[useidx]
while idx < lastindex(range)
ndata = tpdum.data[idx+1]
ndata == 0 && break
tpdum.data[idx] = ndata
end
tpdum.data[idx + 1] = 0
end
end
kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) =
Expand Down Expand Up @@ -262,11 +272,11 @@ function process_terminator!(ir::IRCode, idx::Int, bb::Int,
end
return false
elseif isa(inst, GotoNode)
backedge = inst.label < bb
backedge = inst.label <= bb
!backedge && push!(ip, inst.label)
return backedge
elseif isa(inst, GotoIfNot)
backedge = inst.dest < bb
backedge = inst.dest <= bb
!backedge && push!(ip, inst.dest)
push!(ip, bb + 1)
return backedge
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Each (abstract) call to `iterate`, corresponds to one entry in `ainfo.each::Vect
"""
struct AbstractIterationInfo
each::Vector{CallMeta}
complete::Bool
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this change is necessary for this PR or it is a separate fix on the inlining algorithm?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds all the calls into the .each field (partly to look at the affects, but mostly for Cthulhu's benefit). Previously the inlining algorithm used the presence or absence of this info to determine whether or not we had the complete set of iteration items. Now we need to keep track of it separately.

end

const MaybeAbstractIterationInfo = Union{Nothing, AbstractIterationInfo}
Expand Down
16 changes: 16 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4633,3 +4633,19 @@ end |> only === Type{Float64}
# Issue #46839: `abstract_invoke` should handle incorrect call type
@test only(Base.return_types(()->invoke(BitSet, Any, x), ())) === Union{}
@test only(Base.return_types(()->invoke(BitSet, Union{Tuple{Int32},Tuple{Int64}}, 1), ())) === Union{}

# Issue #47688: Abstract iteration should take into account `iterate` effects
global it_count47688 = 0
struct CountsIterate47688{N}; end
function Base.iterate(::CountsIterate47688{N}, n=0) where N
global it_count47688 += 1
n <= N ? (n, n+1) : nothing
end
foo47688() = tuple(CountsIterate47688{5}()...)
bar47688() = foo47688()
@test only(Base.return_types(bar47688)) == NTuple{6, Int}
@test it_count47688 == 0
@test isa(bar47688(), NTuple{6, Int})
@test it_count47688 == 7
@test isa(foo47688(), NTuple{6, Int})
@test it_count47688 == 14
Keno marked this conversation as resolved.
Show resolved Hide resolved