Skip to content

Commit

Permalink
Taking world ages seriously (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Mar 13, 2023
1 parent 4e98899 commit 860ec6a
Show file tree
Hide file tree
Showing 19 changed files with 516 additions and 297 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
authors = ["Tim Besard <[email protected]>"]
version = "0.17.3"
version = "0.18.0"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
4 changes: 2 additions & 2 deletions examples/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntim
kernel() = nothing

# Example entry point: build a compiler job for the no-op `kernel` function and print the
# first element of what `GPUCompiler.compile(:asm, job)` returns.
# NOTE(review): `source` and `job` are each assigned twice in a row; this looks like the
# removed and added lines of a diff shown together -- confirm which pair is current.
function main()
source = FunctionSpec(typeof(kernel))
source = FunctionSpec(typeof(kernel), Tuple{})
target = NativeCompilerTarget()
params = TestCompilerParams()
job = CompilerJob(target, source, params)
job = CompilerJob(source, target, params)

println(GPUCompiler.compile(:asm, job)[1])
end
Expand Down
146 changes: 114 additions & 32 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,82 @@
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
using Base: _methods_by_ftype

# generated function that crafts a custom code info to call the actual compiler.
# this gives us the flexibility to insert manual back edges for automatic recompilation.
# generated function that returns the world age of a compilation job. this can be used to
# drive compilation, e.g. by using it as a key for a cache, as the age will change when a
# function or any called function is redefined.


"""
    get_world(ft, tt)

A special function that returns the world age in which the current definition of function
type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
compilation results:

    compilation_cache = Dict()
    function cache_compilation(ft, tt)
        world = get_world(ft, tt)
        get!(compilation_cache, (ft, tt, world)) do
            # compile
        end
    end

What makes this function special is that it is a generated function, returning a constant,
whose result is automatically invalidated when the function `ft` (or any called function) is
redefined. This makes this query ideally suited for hot code, where you want to avoid a
costly look-up of the current world age on every invocation.

Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.

!!! warning

    Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
    guaranteed to be correctly invalidated when the target function `ft` is executed or
    processed by codegen (e.g., by calling `code_llvm`).
"""
get_world

# generated functions currently do not know which world they are invoked for, so we fall
# back to using the current world. this may be wrong when the generator is invoked in a
# different world (TODO: when does this happen?)
#
# we also increment a global specialization counter and pass it along to index the cache.

const specialization_counter = Ref{UInt}(0)
@generated function specialization_id(job::CompilerJob{<:Any,<:Any,FunctionSpec{f,tt}}) where {f,tt}
# get a hold of the method and code info of the kernel function
sig = Tuple{f, tt.parameters...}
# XXX: instead of typemax(UInt) we should use the world-age of the fspec
mthds = _methods_by_ftype(sig, -1, typemax(UInt))
# XXX: this should be fixed by JuliaLang/julia#48611

function get_world_generator(self, ::Type{Type{ft}}, ::Type{Type{tt}}) where {ft, tt}
@nospecialize

# look up the method
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
mthds = if VERSION >= v"1.7.0-DEV.1297"
Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1,
#=world=# typemax(UInt), #=ambig=# false,
min_world, max_world, has_ambig)
# XXX: use the correct method table to support overlaying kernels
else
Base._methods_by_ftype(sig, #=lim=# -1,
#=world=# typemax(UInt), #=ambig=# false,
min_world, max_world, has_ambig)
end
# XXX: using world=-1 is wrong, but the current world isn't exposed to this generator

# check the validity of the method matches
method_error = :(throw(MethodError(ft, tt)))
mthds === nothing && return method_error
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
length(mthds) == 1 || return (:(throw(MethodError(job.source.f,job.source.tt))))
length(mthds) == 1 || return method_error

# look up the method and code instance
mtypes, msp, m = mthds[1]
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
ci = retrieve_code_info(mi)::CodeInfo

# generate a unique id to represent this specialization
# TODO: just use the lower world age bound in which this code info is valid.
# (the method instance doesn't change when called functions are changed).
# but how to get that? the ci here always has min/max world 1/-1.
# XXX: don't use `objectid(ci)` here, apparently it can alias (or the CI doesn't change?)
id = (specialization_counter[] += 1)
# XXX: we don't know the world age that this generator was requested to run in, so use
# the current world (we cannot use the mi's world because that doesn't update when
# called functions are changed). this isn't correct, but should be close.
world = Base.get_world_counter()

# prepare a new code info
new_ci = copy(ci)
Expand All @@ -34,22 +87,20 @@ const specialization_counter = Ref{UInt}(0)
resize!(new_ci.linetable, 1) # see note below
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0
new_ci.min_world = min_world[]
new_ci.max_world = max_world[]
new_ci.edges = MethodInstance[mi]
# XXX: setting this edge does not give us proper method invalidation, see
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# invoking `code_llvm` also does the necessary codegen, as does calling the
# underlying C methods -- which GPUCompiler does, so everything Just Works.

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :cache, :job, :compiler, :linker]
new_ci.slotflags = UInt8[0x00 for i = 1:5]
cache = SlotNumber(2)
job = SlotNumber(3)
compiler = SlotNumber(4)
linker = SlotNumber(5)

# call the compiler
push!(new_ci.code, ReturnNode(id))
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the world
push!(new_ci.code, ReturnNode(world))
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
push!(new_ci.codelocs, 1) # see note below
new_ci.ssavaluetypes += 1
Expand All @@ -62,17 +113,48 @@ const specialization_counter = Ref{UInt}(0)
return new_ci
end

# Definition of `get_world(ft, tt)`. The body is assembled by hand with `@eval`: it consists
# only of two `:meta` expressions, `:generated_only` and a `:generated` entry wrapping a
# `Core.GeneratedFunctionStub` that names `get_world_generator` as the generator.
# NOTE(review): the remaining stub fields look like the argument names (`get_world`, `ft`,
# `tt`), an empty list of static parameter names, the source line/file, and a flag -- confirm
# against the `Core.GeneratedFunctionStub` definition of the targeted Julia version.
@eval function get_world(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:get_world_generator,
Any[:get_world, :ft, :tt],
Any[],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
true)))
end

const cache_lock = ReentrantLock()

"""
    cached_compilation(cache::Dict, job::CompilerJob, compiler, linker)

Compile `job` using `compiler` and `linker`, and store the result in `cache`.

The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
and return data that can be cached across sessions (e.g., LLVM IR). This data is then
forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
session-dependent objects (e.g., a `CuModule`).
"""
function cached_compilation(cache::AbstractDict,
@nospecialize(job::CompilerJob),
compiler::Function, linker::Function)
# XXX: CompilerJob contains a world age, so can't be respecialized.
# have specialization_id take a f/tt and return a world to construct a CompilerJob?
key = hash(job, specialization_id(job))
force_compilation = compile_hook[] !== nothing
# NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
# using a world age instead of intersecting world age ranges, because we expect
# that the world age is acquired through calling `get_world` and thus will only
# ever change when the kernel function is redefined.
#
# if we ever want to be able to index the cache using a compilation job that
# contains a more recent world age, yet still return an older cached object that
# would still be valid, we'd need the cache to store world ranges instead and
# use an invalidation callback to add upper bounds to entries.
key = hash(job)

# XXX: by taking the hash, we index the compilation cache directly with the world age.
# that's wrong; we should perform an intersection with the entry's bounds.
force_compilation = compile_hook[] !== nothing

# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
lock(cache_lock)
Expand Down
31 changes: 25 additions & 6 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,13 @@ end

# get the method instance
sig = typed_signature(job)
meth = which(sig)
meth = if VERSION >= v"1.10.0-DEV.65"
Base._which(sig; world=job.source.world).method
elseif VERSION >= v"1.7.0-DEV.435"
Base._which(sig, job.source.world).method
else
ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), sig, job.source.world)
end

(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
Expand All @@ -175,6 +181,10 @@ end
end
end

# ensure that the returned method instance is valid in the compilation world.
# otherwise, `jl_create_native` won't actually emit any code.
@assert method_instance.def.primary_world <= job.source.world <= method_instance.def.deleted_world

return method_instance, ()
end

Expand All @@ -189,9 +199,9 @@ Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
ptr
end

@generated function deferred_codegen(::Val{f}, ::Val{tt}) where {f,tt}
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
id = length(deferred_codegen_jobs) + 1
deferred_codegen_jobs[id] = FunctionSpec(f,tt)
deferred_codegen_jobs[id] = FunctionSpec(ft, tt)

pseudo_ptr = reinterpret(Ptr{Cvoid}, id)
quote
Expand Down Expand Up @@ -286,10 +296,19 @@ const __llvm_initialized = Ref(false)
id = convert(Int, first(operands(call)))

global deferred_codegen_jobs
dyn_job = deferred_codegen_jobs[id]
if dyn_job isa FunctionSpec
dyn_job = similar(job, dyn_job)
dyn_val = deferred_codegen_jobs[id]

# get a job in the appropriate world
dyn_job = if dyn_val isa CompilerJob
dyn_spec = FunctionSpec(dyn_val.source; world=job.source.world)
CompilerJob(dyn_val; source=dyn_spec)
elseif dyn_val isa FunctionSpec
dyn_spec = FunctionSpec(dyn_val; world=job.source.world)
CompilerJob(job; source=dyn_spec)
else
error("invalid deferred job type $(typeof(dyn_val))")
end

push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
end

Expand Down
Loading

0 comments on commit 860ec6a

Please sign in to comment.