Skip to content

Commit

Permalink
effects: taint overlay-ed method's :nonoverlayed effect bit
Browse files Browse the repository at this point in the history
Previously we tainted `:nonoverlayed` bit of the callers of overlay-ed
methods by looking at the method match results, rather than tainting
the overlay-ed methods' effects themselves. This is a bit confusing
since it is not aligned with how the other effect bits are tainted.

Moreover, I am planning to allow `Base.@assume_effects`-override for
`:nonoverlayed` effect bit in the future to solve issues like
JuliaGPU/GPUCompiler.jl#384, and it would be necessary for the solution
to be functional that `:nonoverlayed` effect bit is tainted on the
callee-side as the other effect bits are.

This commit refactors the compiler internal so that we taint
`:nonoverlayed` bit of overlay-ed methods and propagate it to callers.
It turns out that this refactor simplifies the internal implementations
a lot.
  • Loading branch information
aviatesk committed Aug 28, 2023
1 parent f24a93a commit ad24c70
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 63 deletions.
31 changes: 11 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# At this point we are guaranteed to end up throwing on this path,
# which is all that's required for :consistent-cy. Of course, we don't
# know anything else about this statement.
effects = Effects(; consistent=ALWAYS_TRUE, nonoverlayed=!isoverlayed(method_table(interp)))
effects = Effects(; consistent=ALWAYS_TRUE)
return CallMeta(Any, effects, NoCallInfo())
end

Expand All @@ -28,7 +28,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, Effects(), NoCallInfo())
end

(; valid_worlds, applicable, info, nonoverlayed) = matches
(; valid_worlds, applicable, info) = matches
update_valid_age!(sv, valid_worlds)
napplicable = length(applicable)
rettype = Bottom
Expand All @@ -39,7 +39,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
const_results = Union{Nothing,ConstResult}[]
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = Effects(EFFECTS_TOTAL; nonoverlayed)
all_effects = EFFECTS_TOTAL

𝕃ₚ = ipo_lattice(interp)
for i in 1:napplicable
Expand Down Expand Up @@ -205,7 +205,6 @@ struct MethodMatches
valid_worlds::WorldRange
mt::MethodTable
fullmatch::Bool
nonoverlayed::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(m::MethodMatches) = any_ambig(m.info)
Expand All @@ -217,7 +216,6 @@ struct UnionSplitMethodMatches
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
nonoverlayed::Bool
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

Expand All @@ -233,19 +231,16 @@ function find_matching_methods(𝕃::AbstractLattice,
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
nonoverlayed = true
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
result = findall(sig_n, method_table; limit = max_methods)
if result === nothing
matches = findall(sig_n, method_table; limit = max_methods)
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
(; matches, overlayed) = result
nonoverlayed &= !overlayed
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
Expand All @@ -271,28 +266,25 @@ function find_matching_methods(𝕃::AbstractLattice,
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches,
nonoverlayed)
fullmatches)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::MethodTable
result = findall(atype, method_table; limit = max_methods)
if result === nothing
matches = findall(atype, method_table; limit = max_methods)
if matches === nothing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
(; matches, overlayed) = result
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch,
!overlayed)
fullmatch)
end
end

Expand Down Expand Up @@ -862,7 +854,7 @@ function concrete_eval_eligible(interp::AbstractInterpreter,
mi = result.edge
if mi !== nothing && is_foldable(effects)
if f !== nothing && is_all_const_arg(arginfo, #=start=#2)
if is_nonoverlayed(mi.def::Method) && (!isoverlayed(method_table(interp)) || is_nonoverlayed(effects))
if is_nonoverlayed(interp) || is_nonoverlayed(effects)
return :concrete_eval
end
# disable concrete-evaluation if this function call is tainted by some overlayed
Expand Down Expand Up @@ -1924,7 +1916,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
lookupsig = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds, overlayed = findsup(lookupsig, method_table(interp))
match, valid_worlds = findsup(lookupsig, method_table(interp))
match === nothing && return CallMeta(Any, Effects(), NoCallInfo())
update_valid_age!(sv, valid_worlds)
method = match.method
Expand Down Expand Up @@ -1955,7 +1947,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
effects = Effects(effects; nonoverlayed = !overlayed)
info = InvokeCallInfo(match, const_result)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(rt, effects, info)
Expand Down
13 changes: 12 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,11 @@ mutable struct InferenceState
ipo_effects = Effects(ipo_effects; effect_free = ALWAYS_FALSE)
end

restrict_abstract_call_sites = isa(linfo.def, Module)
if def isa Method
ipo_effects = Effects(ipo_effects; nonoverlayed=is_nonoverlayed(def))
end

restrict_abstract_call_sites = isa(def, Module)
@assert cache === :no || cache === :local || cache === :global
cached = cache === :global

Expand All @@ -314,6 +318,13 @@ mutable struct InferenceState
end
end

is_nonoverlayed(m::Method) = !isdefined(m, :external_mt)
is_nonoverlayed(interp::AbstractInterpreter) = !isoverlayed(method_table(interp))
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

Expand Down
51 changes: 16 additions & 35 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ function iterate(result::MethodLookupResult, args...)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

struct MethodMatchResult
matches::MethodLookupResult
overlayed::Bool
end

"""
struct InternalMethodTable <: MethodTableView
Expand Down Expand Up @@ -55,47 +50,42 @@ Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T<:MethodTableView} <: MethodTableView
cache::IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}}
cache::IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}}
table::T
end
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}}(), table)
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}}(), table)

"""
findall(sig::Type, view::MethodTableView; limit::Int=-1) ->
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or nothing
matches::MethodLookupResult or nothing
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified `limit`, `nothing` is returned.
Note that the default setting `limit=-1` does not limit the number of applicable methods.
`overlayed` indicates if any of the matching methods comes from an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1)
result = _findall(sig, nothing, table.world, limit)
result === nothing && return nothing
return MethodMatchResult(result, false)
end
findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1) =
_findall(sig, nothing, table.world, limit)

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=-1)
result = _findall(sig, table.mt, table.world, limit)
result === nothing && return nothing
nr = length(result)
if nr 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return MethodMatchResult(result, true)
return result
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
fallback_result === nothing && return nothing
# merge the fallback match results with the internal method table
return MethodMatchResult(
MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig),
!isempty(result))
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
Expand Down Expand Up @@ -138,21 +128,19 @@ In both cases `nothing` is returned.
`overlayed` indicates if any of the matching methods comes from an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return (_findsup(sig, nothing, table.world)..., false)
end
findsup(@nospecialize(sig::Type), table::InternalMethodTable) =
_findsup(sig, nothing, table.world)

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds, true
match !== nothing && return match, valid_worlds
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
min(valid_worlds.max_world, fallback_valid_worlds.max_world)))
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
Expand All @@ -166,10 +154,3 @@ end

# This query is not cached
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)
isoverlayed(m::Method) = isdefined(m, :external_mt)
is_nonoverlayed(m::Method) = !isoverlayed(m)
2 changes: 1 addition & 1 deletion base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
argtypes === nothing && return Pair{Any,Bool}(Bottom, false)
effects = decode_effects(code.ipo_purity_bits)
if (is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1) &&
is_nonoverlayed(effects) && is_nonoverlayed(mi.def::Method))
(is_nonoverlayed(interp) || is_nonoverlayed(effects)))
args = collect_const_args(argtypes, #=start=#1)
value = let world = get_world_counter(interp)
try
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2763,7 +2763,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
if !isa(mt, MethodTable)
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
end
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match, valid_worlds = findsup(types, method_table(interp))
update_valid_age!(sv, valid_worlds)
if match === nothing
rt = Const(false)
Expand Down
5 changes: 2 additions & 3 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1633,12 +1633,11 @@ function infer_effects(@nospecialize(f), @nospecialize(types=default_tt(f));
Core.Compiler.ArgInfo(nothing, argtypes), rt)
end
tt = signature_type(f, types)
result = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if result === missing
matches = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if matches === missing
# unanalyzable call, return the unknown effects
return Core.Compiler.Effects()
end
(; matches) = result
effects = Core.Compiler.EFFECTS_TOTAL
if matches.ambig || !any(match::Core.MethodMatch->match.fully_covers, matches.matches)
# account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand Down
3 changes: 2 additions & 1 deletion test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ callstrange(::Float64) = strangesin(x)
callstrange(::Nothing) = Core.compilerbarrier(:type, nothing) # trigger inference bail out
callstrange_entry(x) = callstrange(x) # needs to be defined here because of world age
let interp = MTOverlayInterp(Set{Any}())
matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp)).matches
matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp))
@test matches !== nothing
@test Core.Compiler.length(matches) == 2
if Core.Compiler.getindex(matches, 1).method == which(callstrange, (Nothing,))
@test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/datastructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Test
sig = Tuple{typeof(*), Any, Any}
result1 = Core.Compiler.findall(sig, table; limit=-1)
result2 = Core.Compiler.findall(sig, table; limit=Core.Compiler.InferenceParams().max_methods)
@test result1 !== nothing && !Core.Compiler.isempty(result1.matches)
@test result1 !== nothing && !Core.Compiler.isempty(result1)
@test result2 === nothing
end

Expand Down

0 comments on commit ad24c70

Please sign in to comment.