Skip to content

Commit

Permalink
Fix method error generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Feb 22, 2023
1 parent 9017c35 commit 4d91a45
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
46 changes: 24 additions & 22 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,31 @@ using Base: _methods_by_ftype


"""
get_world(f, tt)
get_world(ft, tt)
A special function that returns the world age in which the current definition of function
`f`, invoked with argument types `tt`, is defined. This can be used to cache compilation
results:
type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
compilation results:
compilation_cache = Dict()
function cache_compilation(f, tt)
world = get_world(f, tt)
get!(compilation_cache, (f, tt, world)) do
function cache_compilation(ft, tt)
world = get_world(ft, tt)
get!(compilation_cache, (ft, tt, world)) do
# compile
end
end
What makes this function special is that it is a generated function, returning a constant,
whose result is automatically invalidated when the function `f` (or any called function)
is redefined. This makes this query ideally suited for hot code, where you want to avoid
a costly look-up of the current world age on every invocation.
whose result is automatically invalidated when the function `ft` (or any called function) is
redefined. This makes this query ideally suited for hot code, where you want to avoid a
costly look-up of the current world age on every invocation.
Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.
!!! warning
Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
guaranteed to be correctly invalidated when the target function `f` is executed or
guaranteed to be correctly invalidated when the target function `ft` is executed or
processed by codegen (e.g., by calling `code_llvm`).
"""
get_world
Expand All @@ -41,15 +43,15 @@ if VERSION >= v"1.10.0-DEV.649"
# on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in

function _generated_ex(world, source, ex)
stub = Core.GeneratedFunctionStub(identity, Core.svec(:get_world, :job), Core.svec())
stub = Core.GeneratedFunctionStub(identity, Core.svec(:get_world, :ft, :tt), Core.svec())
stub(world, source, ex)
end

function get_world_generator(world::UInt, source, self, ::Type{Type{f}}, ::Type{Type{tt}}) where {f, tt}
function get_world_generator(world::UInt, source, self, ::Type{Type{ft}}, ::Type{Type{tt}}) where {ft, tt}
@nospecialize

# look up the method
sig = Tuple{f, tt.parameters...}
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
Expand All @@ -65,7 +67,7 @@ function get_world_generator(world::UInt, source, self, ::Type{Type{f}}, ::Type{
end

# check the validity of the method matches
method_error = :(throw(MethodError(f, tt, $world)))
method_error = :(throw(MethodError(ft, tt, $world)))
mthds === nothing && return _generated_ex(world, source, method_error)
Base.isdispatchtuple(tt) || return _generated_ex(world, source, :(error("$tt is not a dispatch tuple")))
length(mthds) == 1 || return _generated_ex(world, source, method_error)
Expand All @@ -91,7 +93,7 @@ function get_world_generator(world::UInt, source, self, ::Type{Type{f}}, ::Type{
# underlying C methods -- which GPUCompiler does, so everything Just Works.

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :f, :tt]
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the world
Expand All @@ -108,7 +110,7 @@ function get_world_generator(world::UInt, source, self, ::Type{Type{f}}, ::Type{
return new_ci
end

@eval function get_world(f, tt)
@eval function get_world(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, get_world_generator))
end
Expand All @@ -118,11 +120,11 @@ else
# on older versions of Julia we fall back to looking up the current world. this may be wrong
# when the generator is invoked in a different world (TODO: when does this happen?)

function get_world_generator(self, ::Type{Type{f}}, ::Type{Type{tt}}) where {f, tt}
function get_world_generator(self, ::Type{Type{ft}}, ::Type{Type{tt}}) where {ft, tt}
@nospecialize

# look up the method
sig = Tuple{f, tt.parameters...}
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
Expand All @@ -139,7 +141,7 @@ function get_world_generator(self, ::Type{Type{f}}, ::Type{Type{tt}}) where {f,
# XXX: using world=-1 is wrong, but the current world isn't exposed to this generator

# check the validity of the method matches
method_error = :(throw(MethodError(f, tt)))
method_error = :(throw(MethodError(ft, tt)))
mthds === nothing && return method_error
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
length(mthds) == 1 || return method_error
Expand Down Expand Up @@ -170,7 +172,7 @@ function get_world_generator(self, ::Type{Type{f}}, ::Type{Type{tt}}) where {f,
# underlying C methods -- which GPUCompiler does, so everything Just Works.

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :f, :tt]
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the world
Expand All @@ -187,14 +189,14 @@ function get_world_generator(self, ::Type{Type{f}}, ::Type{Type{tt}}) where {f,
return new_ci
end

@eval function get_world(f, tt)
@eval function get_world(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:get_world_generator,
Any[:get_world, :f, :tt],
Any[:get_world, :ft, :tt],
Any[],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
Expand Down
22 changes: 13 additions & 9 deletions src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ function return_type(m::Core.MethodMatch; interp::AbstractInterpreter)
return something(ty, Any)
end

# create a MethodError from a function type
# TODO: fix upstream
function MethodError(ft::Type, tt::Type, world::Integer=typemax(UInt))
f = if isdefined(ft, :instance)
ft.instance
else
# HACK: dealing with a closure or something... let's do somthing really invalid,
# which works because MethodError doesn't actually use the function
Ref{ft}()[]
end
Base.MethodError(f, tt, world)
end

function check_method(@nospecialize(job::CompilerJob))
isa(job.source.ft, Core.Builtin) &&
Expand All @@ -26,15 +38,7 @@ function check_method(@nospecialize(job::CompilerJob))
# get the method
ms = method_matches(typed_signature(job); job.source.world)
if length(ms) != 1
# we only have a function type, but MethodError needs an instance...
f = if isdefined(job.source.ft, :instance)
job.source.ft.instance
else
# HACK: dealing with a closure or something... let's do somthing really invalid,
# which works because MethodError doesn't actually use the function
Ref{job.source.ft}()[]
end
throw(MethodError(f, job.source.tt, job.source.world))
throw(MethodError(job.source.ft, job.source.tt, job.source.world))
end

# kernels can't return values
Expand Down

0 comments on commit 4d91a45

Please sign in to comment.