Skip to content

Commit

Permalink
abstract_apply: Don't drop effects of iterate'd calls (#47846)
Browse files Browse the repository at this point in the history
We were accidentally dropping the effects of calls from
`iterate` calls performed during abstract_iteration. This
allowed calls that were not actually eligible for (semi-)concrete
evaluation to go through that path anyway. This could cause
incorrect results (see test), though it was usually fine, since
iterate call tend to not have side effects. It was noticed
however in #47688, because it forced irinterp down a path
that was not meant to be reachable (resulting in a TODO
error message). For good measure, let's also address this
todo (since it is reachable by external absint if they want),
but the missing effect propagation was the more serious bug
here.

(cherry picked from commit 2a0d58a)
  • Loading branch information
Keno authored and KristofferC committed Mar 7, 2023
1 parent 3b2e0d8 commit e970518
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 32 deletions.
67 changes: 42 additions & 25 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1304,21 +1304,32 @@ function ssa_def_slot(@nospecialize(arg), sv::InferenceState)
return arg
end

struct AbstractIterationResult
cti::Vector{Any}
info::MaybeAbstractIterationInfo
ai_effects::Effects
end
AbstractIterationResult(cti::Vector{Any}, info::MaybeAbstractIterationInfo) =
AbstractIterationResult(cti, info, EFFECTS_TOTAL)

# `typ` is the inferred type for expression `arg`.
# if the expression constructs a container (e.g. `svec(x,y,z)`),
# refine its type to an array of element types.
# Union of Tuples of the same length is converted to Tuple of Unions.
# returns an array of types
function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ),
sv::Union{InferenceState, IRCode})
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
return typ.fields, nothing
if isa(typ, PartialStruct)
widet = typ.typ
if isa(widet, DataType) && widet.name === Tuple.name
return AbstractIterationResult(typ.fields, nothing)
end
end

if isa(typ, Const)
val = typ.val
if isa(val, SimpleVector) || isa(val, Tuple)
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
return AbstractIterationResult(Any[ Const(val[i]) for i in 1:length(val) ], nothing) # avoid making a tuple Generator here!
end
end

Expand All @@ -1333,12 +1344,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(tti, Union)
utis = uniontypes(tti)
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
end
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing)
end
end
result = Any[ Union{} for _ in 1:ltp ]
Expand All @@ -1349,12 +1360,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return result, nothing
return AbstractIterationResult(result, nothing)
elseif tti0 <: Tuple
if isa(tti0, DataType)
return Any[ p for p in tti0.parameters ], nothing
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
elseif !isa(tti, DataType)
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing)
else
len = length(tti.parameters)
last = tti.parameters[len]
Expand All @@ -1363,12 +1374,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if va
elts[len] = Vararg{elts[len]}
end
return elts, nothing
return AbstractIterationResult(elts, nothing)
end
elseif tti0 === SimpleVector || tti0 === Any
return Any[Vararg{Any}], nothing
elseif tti0 === SimpleVector
return AbstractIterationResult(Any[Vararg{Any}], nothing)
elseif tti0 === Any
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
elseif tti0 <: Array
return Any[Vararg{eltype(tti0)}], nothing
return AbstractIterationResult(Any[Vararg{eltype(tti0)}], nothing)
else
return abstract_iteration(interp, itft, typ, sv)
end
Expand All @@ -1379,7 +1392,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
if isa(itft, Const)
iteratef = itft.val
else
return Any[Vararg{Any}], nothing
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
end
@assert !isvarargtype(itertype)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), StmtInfo(true), sv)
Expand All @@ -1389,7 +1402,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
# TODO: this doesn't realize that Array, SimpleVector, Tuple, and NamedTuple do not use the iterate protocol
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)])
stateordonet === Bottom && return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)], true))
valtype = statetype = Bottom
ret = Any[]
calls = CallMeta[call]
Expand All @@ -1399,7 +1412,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# length iterators, or interesting prefix
while true
if stateordonet_widened === Nothing
return ret, AbstractIterationInfo(calls)
return AbstractIterationResult(ret, AbstractIterationInfo(calls, true))
end
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
break
Expand All @@ -1411,7 +1424,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if (typeinf_lattice(interp), nstatetype, statetype)
return Any[Bottom], nothing
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_THROWS)
end
valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1))
push!(ret, valtype)
Expand Down Expand Up @@ -1441,7 +1454,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# ... but cannot terminate
if !may_have_terminated
# ... and cannot have terminated prior to this loop
return Any[Bottom], nothing
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_UNKNOWN′)
else
# iterator may have terminated prior to this loop, but not during it
valtype = Bottom
Expand All @@ -1451,13 +1464,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv).rt
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv)
push!(calls, call)
stateordonet = call.rt
stateordonet_widened = widenconst(stateordonet)
end
if valtype !== Union{}
push!(ret, Vararg{valtype})
end
return ret, nothing
return AbstractIterationResult(ret, AbstractIterationInfo(calls, false))
end

# do apply(af, fargs...), where af is a function value
Expand Down Expand Up @@ -1488,13 +1503,9 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
infos′ = Vector{MaybeAbstractIterationInfo}[]
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
if !isvarargtype(ti)
cti_info = precise_container_type(interp, itft, ti, sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::MaybeAbstractIterationInfo
(;cti, info, ai_effects) = precise_container_type(interp, itft, ti, sv)
else
cti_info = precise_container_type(interp, itft, unwrapva(ti), sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::MaybeAbstractIterationInfo
(;cti, info, ai_effects) = precise_container_type(interp, itft, unwrapva(ti), sv)
# We can't represent a repeating sequence of the same types,
# so tmerge everything together to get one type that represents
# everything.
Expand All @@ -1507,6 +1518,12 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
cti = Any[Vararg{argt}]
end
effects = merge_effects(effects, ai_effects)
if info !== nothing
for call in info.each
effects = merge_effects(effects, call.effects)
end
end
if any(@nospecialize(t) -> t === Bottom, cti)
continue
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ function rewrite_apply_exprargs!(todo::Vector{Pair{Int,Any}},
def = argexprs[i]
def_type = argtypes[i]
thisarginfo = arginfos[i-arg_start]
if thisarginfo === nothing
if thisarginfo === nothing || !thisarginfo.complete
if def_type isa PartialStruct
# def_type.typ <: Tuple is assumed
def_argtypes = def_type.fields
Expand Down Expand Up @@ -1141,9 +1141,9 @@ function inline_apply!(todo::Vector{Pair{Int,Any}},
for i = (arg_start + 1):length(argtypes)
thisarginfo = nothing
if !is_valid_type_for_apply_rewrite(argtypes[i], state.params)
if isa(info, ApplyCallInfo) && info.arginfo[i-arg_start] !== nothing
thisarginfo = info.arginfo[i-arg_start]
else
isa(info, ApplyCallInfo) || return nothing
thisarginfo = info.arginfo[i-arg_start]
if thisarginfo === nothing || !thisarginfo.complete
return nothing
end
end
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int)
if !tpdum.complete
tpdum.ssa_uses[def] -= 1
else
@assert false && "TODO"
range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1))
# TODO: Sorted
useidx = findfirst(idx->tpdum.data[idx] == use, range)
@assert useidx !== nothing
idx = range[useidx]
while idx < lastindex(range)
ndata = tpdum.data[idx+1]
ndata == 0 && break
tpdum.data[idx] = ndata
end
tpdum.data[idx + 1] = 0
end
end
kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) =
Expand Down Expand Up @@ -261,11 +271,11 @@ function process_terminator!(ir::IRCode, idx::Int, bb::Int,
end
return false
elseif isa(inst, GotoNode)
backedge = inst.label < bb
backedge = inst.label <= bb
!backedge && push!(ip, inst.label)
return backedge
elseif isa(inst, GotoIfNot)
backedge = inst.dest < bb
backedge = inst.dest <= bb
!backedge && push!(ip, inst.dest)
push!(ip, bb + 1)
return backedge
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Each (abstract) call to `iterate`, corresponds to one entry in `ainfo.each::Vect
"""
struct AbstractIterationInfo
each::Vector{CallMeta}
complete::Bool
end

const MaybeAbstractIterationInfo = Union{Nothing, AbstractIterationInfo}
Expand Down
16 changes: 16 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4293,3 +4293,19 @@ unknown_sparam_nothrow2(x::Ref{Ref{T}}) where T = @isdefined(T) ? T::Type : noth
@test only(Base.return_types(unknown_sparam_throw, (Any,))) === Union{Nothing,Type}
@test only(Base.return_types(unknown_sparam_nothrow1, (Ref,))) === Type
@test only(Base.return_types(unknown_sparam_nothrow2, (Ref{Ref{T}} where T,))) === Type

# Issue #47688: Abstract iteration should take into account `iterate` effects
global it_count47688 = 0
struct CountsIterate47688{N}; end
function Base.iterate(::CountsIterate47688{N}, n=0) where N
global it_count47688 += 1
n <= N ? (n, n+1) : nothing
end
foo47688() = tuple(CountsIterate47688{5}()...)
bar47688() = foo47688()
@test only(Base.return_types(bar47688)) == NTuple{6, Int}
@test it_count47688 == 0
@test isa(bar47688(), NTuple{6, Int})
@test it_count47688 == 7
@test isa(foo47688(), NTuple{6, Int})
@test it_count47688 == 14

0 comments on commit e970518

Please sign in to comment.