From ad24c70f983cf3aec9b654d0163dba63bb9581f6 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 28 Aug 2023 15:49:03 +0900 Subject: [PATCH] effects: taint overlay-ed method's `:nonoverlayed` effect bit Previously we tainted `:nonoverlayed` bit of the callers of overlay-ed methods by looking at the method match results, rather than tainting the overlay-ed methods' effects themselves. This is a bit confusing since it is not aligned with how the other effect bits are tainted. Moreover, I am planning to allow `Base.@assume_effects`-override for `:nonoverlayed` effect bit in the future to solve issues like JuliaGPU/GPUCompiler.jl#384, and it would be necessary for the solution to be functional that `:nonoverlayed` effect bit is tainted on the callee-side as the other effect bits are. This commit refactors the compiler internal so that we taint `:nonoverlayed` bit of overlay-ed methods and propagate it to callers. It turns out that this refactor simplifies the internal implementations a lot. --- base/compiler/abstractinterpretation.jl | 31 ++++++--------- base/compiler/inferencestate.jl | 13 ++++++- base/compiler/methodtable.jl | 51 ++++++++----------------- base/compiler/ssair/irinterp.jl | 2 +- base/compiler/tfuncs.jl | 2 +- base/reflection.jl | 5 +-- test/compiler/AbstractInterpreter.jl | 3 +- test/compiler/datastructures.jl | 2 +- 8 files changed, 46 insertions(+), 63 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 39fea90cbe9936..1893abd7820eee 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -16,7 +16,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # At this point we are guaranteed to end up throwing on this path, # which is all that's required for :consistent-cy. Of course, we don't # know anything else about this statement. - effects = Effects(; consistent=ALWAYS_TRUE, nonoverlayed=!isoverlayed(method_table(interp))) + effects = Effects(; consistent=ALWAYS_TRUE) return CallMeta(Any, effects, NoCallInfo()) end @@ -28,7 +28,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(Any, Effects(), NoCallInfo()) end - (; valid_worlds, applicable, info, nonoverlayed) = matches + (; valid_worlds, applicable, info) = matches update_valid_age!(sv, valid_worlds) napplicable = length(applicable) rettype = Bottom @@ -39,7 +39,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), const_results = Union{Nothing,ConstResult}[] multiple_matches = napplicable > 1 fargs = arginfo.fargs - all_effects = Effects(EFFECTS_TOTAL; nonoverlayed) + all_effects = EFFECTS_TOTAL 𝕃ₚ = ipo_lattice(interp) for i in 1:napplicable @@ -205,7 +205,6 @@ struct MethodMatches valid_worlds::WorldRange mt::MethodTable fullmatch::Bool - nonoverlayed::Bool end any_ambig(info::MethodMatchInfo) = info.results.ambig any_ambig(m::MethodMatches) = any_ambig(m.info) @@ -217,7 +216,6 @@ struct UnionSplitMethodMatches valid_worlds::WorldRange mts::Vector{MethodTable} fullmatches::Vector{Bool} - nonoverlayed::Bool end any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches) @@ -233,19 +231,16 @@ function find_matching_methods(𝕃::AbstractLattice, valid_worlds = WorldRange() mts = MethodTable[] fullmatches = Bool[] - nonoverlayed = true for i in 1:length(split_argtypes) arg_n = split_argtypes[i]::Vector{Any} sig_n = argtypes_to_type(arg_n) mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) mt === nothing && return FailedMethodMatch("Could not identify method table for call") mt = mt::MethodTable - result = findall(sig_n, method_table; limit = max_methods) - if result === nothing + matches = findall(sig_n, method_table; limit = max_methods) + if matches === nothing return FailedMethodMatch("For one of the union split cases, too many methods matched") end - (; matches, overlayed) = result - nonoverlayed &= !overlayed push!(infos, MethodMatchInfo(matches)) for m in matches push!(applicable, m) @@ -271,28 +266,25 @@ function find_matching_methods(𝕃::AbstractLattice, UnionSplitInfo(infos), valid_worlds, mts, - fullmatches, - nonoverlayed) + fullmatches) else mt = ccall(:jl_method_table_for, Any, (Any,), atype) if mt === nothing return FailedMethodMatch("Could not identify method table for call") end mt = mt::MethodTable - result = findall(atype, method_table; limit = max_methods) - if result === nothing + matches = findall(atype, method_table; limit = max_methods) + if matches === nothing # this means too many methods matched # (assume this will always be true, so we don't compute / update valid age in this case) return FailedMethodMatch("Too many methods matched") end - (; matches, overlayed) = result fullmatch = any(match::MethodMatch->match.fully_covers, matches) return MethodMatches(matches.matches, MethodMatchInfo(matches), matches.valid_worlds, mt, - fullmatch, - !overlayed) + fullmatch) end end @@ -862,7 +854,7 @@ function concrete_eval_eligible(interp::AbstractInterpreter, mi = result.edge if mi !== nothing && is_foldable(effects) if f !== nothing && is_all_const_arg(arginfo, #=start=#2) - if is_nonoverlayed(mi.def::Method) && (!isoverlayed(method_table(interp)) || is_nonoverlayed(effects)) + if is_nonoverlayed(interp) || is_nonoverlayed(effects) return :concrete_eval end # disable concrete-evaluation if this function call is tainted by some overlayed @@ -1924,7 +1916,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn lookupsig = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type nargtype = Tuple{ft, nargtype.parameters...} argtype = Tuple{ft, argtype.parameters...} - match, valid_worlds, overlayed = findsup(lookupsig, method_table(interp)) + match, valid_worlds = findsup(lookupsig, method_table(interp)) match === nothing && return CallMeta(Any, Effects(), NoCallInfo()) update_valid_age!(sv, valid_worlds) method = match.method @@ -1955,7 +1947,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn end end rt = from_interprocedural!(interp, rt, sv, arginfo, sig) - effects = Effects(effects; nonoverlayed = !overlayed) info = InvokeCallInfo(match, const_result) edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge) return CallMeta(rt, effects, info) diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 86d6c046b45536..a36228707cb14c 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -296,7 +296,11 @@ mutable struct InferenceState ipo_effects = Effects(ipo_effects; effect_free = ALWAYS_FALSE) end - restrict_abstract_call_sites = isa(linfo.def, Module) + if def isa Method + ipo_effects = Effects(ipo_effects; nonoverlayed=is_nonoverlayed(def)) + end + + restrict_abstract_call_sites = isa(def, Module) @assert cache === :no || cache === :local || cache === :global cached = cache === :global @@ -314,6 +318,13 @@ mutable struct InferenceState end end +is_nonoverlayed(m::Method) = !isdefined(m, :external_mt) +is_nonoverlayed(interp::AbstractInterpreter) = !isoverlayed(method_table(interp)) +isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface") +isoverlayed(::InternalMethodTable) = false +isoverlayed(::OverlayMethodTable) = true +isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table) + is_inferred(sv::InferenceState) = is_inferred(sv.result) is_inferred(result::InferenceResult) = result.result !== nothing diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index 804cd395532f30..470fdc413e3201 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -16,11 +16,6 @@ function iterate(result::MethodLookupResult, args...) end getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch -struct MethodMatchResult - matches::MethodLookupResult - overlayed::Bool -end - """ struct InternalMethodTable <: MethodTableView @@ -55,14 +50,14 @@ Overlays another method table view with an additional local fast path cache that can respond to repeated, identical queries faster than the original method table. """ struct CachedMethodTable{T<:MethodTableView} <: MethodTableView - cache::IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}} + cache::IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}} table::T end -CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}}(), table) +CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}}(), table) """ findall(sig::Type, view::MethodTableView; limit::Int=-1) -> - MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or nothing + matches::MethodLookupResult or nothing Find all methods in the given method table `view` that are applicable to the given signature `sig`. If no applicable methods are found, an empty result is returned. @@ -70,11 +65,8 @@ If the number of applicable methods exceeded the specified `limit`, `nothing` is Note that the default setting `limit=-1` does not limit the number of applicable methods. `overlayed` indicates if any of the matching methods comes from an overlayed method table. """ -function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1) - result = _findall(sig, nothing, table.world, limit) - result === nothing && return nothing - return MethodMatchResult(result, false) -end +findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1) = + _findall(sig, nothing, table.world, limit) function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=-1) result = _findall(sig, table.mt, table.world, limit) @@ -82,20 +74,18 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int nr = length(result) if nr ≥ 1 && result[nr].fully_covers # no need to fall back to the internal method table - return MethodMatchResult(result, true) + return result end # fall back to the internal method table fallback_result = _findall(sig, nothing, table.world, limit) fallback_result === nothing && return nothing # merge the fallback match results with the internal method table - return MethodMatchResult( - MethodLookupResult( - vcat(result.matches, fallback_result.matches), - WorldRange( - max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), - min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), - result.ambig | fallback_result.ambig), - !isempty(result)) + return MethodLookupResult( + vcat(result.matches, fallback_result.matches), + WorldRange( + max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), + min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), + result.ambig | fallback_result.ambig) end function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int) @@ -138,21 +128,19 @@ In both cases `nothing` is returned. `overlayed` indicates if any of the matching methods comes from an overlayed method table. """ -function findsup(@nospecialize(sig::Type), table::InternalMethodTable) - return (_findsup(sig, nothing, table.world)..., false) -end +findsup(@nospecialize(sig::Type), table::InternalMethodTable) = + _findsup(sig, nothing, table.world) function findsup(@nospecialize(sig::Type), table::OverlayMethodTable) match, valid_worlds = _findsup(sig, table.mt, table.world) - match !== nothing && return match, valid_worlds, true + match !== nothing && return match, valid_worlds # fall back to the internal method table fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world) return ( fallback_match, WorldRange( max(valid_worlds.min_world, fallback_valid_worlds.min_world), - min(valid_worlds.max_world, fallback_valid_worlds.max_world)), - false) + min(valid_worlds.max_world, fallback_valid_worlds.max_world))) end function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt) @@ -166,10 +154,3 @@ end # This query is not cached findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table) - -isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface") -isoverlayed(::InternalMethodTable) = false -isoverlayed(::OverlayMethodTable) = true -isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table) -isoverlayed(m::Method) = isdefined(m, :external_mt) -is_nonoverlayed(m::Method) = !isoverlayed(m) diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index ec3b769a9992af..e174176695e662 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -15,7 +15,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter, argtypes === nothing && return Pair{Any,Bool}(Bottom, false) effects = decode_effects(code.ipo_purity_bits) if (is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1) && - is_nonoverlayed(effects) && is_nonoverlayed(mi.def::Method)) + (is_nonoverlayed(interp) || is_nonoverlayed(effects))) args = collect_const_args(argtypes, #=start=#1) value = let world = get_world_counter(interp) try diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 20996be4faaedf..4857f4bdcc30dd 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2763,7 +2763,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv if !isa(mt, MethodTable) return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo()) end - match, valid_worlds, overlayed = findsup(types, method_table(interp)) + match, valid_worlds = findsup(types, method_table(interp)) update_valid_age!(sv, valid_worlds) if match === nothing rt = Const(false) diff --git a/base/reflection.jl b/base/reflection.jl index ffd3882bf74ab9..1d1399c37f5f96 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -1633,12 +1633,11 @@ function infer_effects(@nospecialize(f), @nospecialize(types=default_tt(f)); Core.Compiler.ArgInfo(nothing, argtypes), rt) end tt = signature_type(f, types) - result = Core.Compiler.findall(tt, Core.Compiler.method_table(interp)) - if result === missing + matches = Core.Compiler.findall(tt, Core.Compiler.method_table(interp)) + if matches === missing # unanalyzable call, return the unknown effects return Core.Compiler.Effects() end - (; matches) = result effects = Core.Compiler.EFFECTS_TOTAL if matches.ambig || !any(match::Core.MethodMatch->match.fully_covers, matches.matches) # account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature. diff --git a/test/compiler/AbstractInterpreter.jl b/test/compiler/AbstractInterpreter.jl index edb6546499779e..5df755af5dbf76 100644 --- a/test/compiler/AbstractInterpreter.jl +++ b/test/compiler/AbstractInterpreter.jl @@ -55,7 +55,8 @@ callstrange(::Float64) = strangesin(x) callstrange(::Nothing) = Core.compilerbarrier(:type, nothing) # trigger inference bail out callstrange_entry(x) = callstrange(x) # needs to be defined here because of world age let interp = MTOverlayInterp(Set{Any}()) - matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp)).matches + matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp)) + @test matches !== nothing @test Core.Compiler.length(matches) == 2 if Core.Compiler.getindex(matches, 1).method == which(callstrange, (Nothing,)) @test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed diff --git a/test/compiler/datastructures.jl b/test/compiler/datastructures.jl index 8dbaee61503d0c..3f867ab59b2b37 100644 --- a/test/compiler/datastructures.jl +++ b/test/compiler/datastructures.jl @@ -8,7 +8,7 @@ using Test sig = Tuple{typeof(*), Any, Any} result1 = Core.Compiler.findall(sig, table; limit=-1) result2 = Core.Compiler.findall(sig, table; limit=Core.Compiler.InferenceParams().max_methods) - @test result1 !== nothing && !Core.Compiler.isempty(result1.matches) + @test result1 !== nothing && !Core.Compiler.isempty(result1) @test result2 === nothing end