Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More alloc cache improvements #583

Merged
merged 6 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
check_eltype(T)
maxsize = prod(dims) * sizeof(T)

return GPUArrays.cached_alloc((JLArray, T, dims)) do
ref = GPUArrays.cached_alloc((JLArray, maxsize)) do
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data) do data
DataRef(data) do data
resize!(data, 0)
end
obj = new{T, N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
return obj
end::JLArray{T, N}
end

obj = new{T, N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
return obj
end

# low-level constructor for wrapping existing data
Expand Down
25 changes: 17 additions & 8 deletions src/host/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,20 @@ end

# per-object state, with a flag to indicate whether the object has been freed.
# this is to support multiple calls to `unsafe_free!` on the same object,
# while only lowering the referene count of the underlying data once.
# while only lowering the reference count of the underlying data once.
# Per-object state around a shared `RefCounted` payload. `freed` lets
# `unsafe_free!` be called multiple times on the same object while releasing
# the underlying refcount only once; `cached` marks references whose lifetime
# is owned by an `AllocCache` (see `cached_alloc` and `unsafe_free!`).
mutable struct DataRef{D}
rc::RefCounted{D} # shared, reference-counted handle to the actual data
freed::Bool # whether `unsafe_free!` has already run on this object
cached::Bool # whether this reference is currently owned by an alloc cache
end

function DataRef(finalizer, data::D) where {D}
rc = RefCounted{D}(data, finalizer, Threads.Atomic{Int}(1))
DataRef{D}(rc, false)
function DataRef(finalizer, ref::D) where {D}
rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1))
DataRef{D}(rc, false, false)
end
DataRef(data; kwargs...) = DataRef(nothing, data; kwargs...)
DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...)

# Size (in bytes) reported by the data underlying this reference.
function Base.sizeof(ref::DataRef)
    return sizeof(ref.rc[])
end

function Base.getindex(ref::DataRef)
if ref.freed
Expand All @@ -77,18 +80,24 @@ function Base.copy(ref::DataRef{D}) where {D}
throw(ArgumentError("Attempt to copy a freed reference."))
end
retain(ref.rc)
return DataRef{D}(ref.rc, false)
# copies of cached references are not managed by the cache, so
# we need to mark them as such to make sure their refcount can drop.
return DataRef{D}(ref.rc, false, false)
end

"""
    unsafe_free!(ref::DataRef)

Drop this object's reference to the underlying data, lowering the refcount
once. Calling it again on the same object is a no-op, and references owned by
an `AllocCache` are left alone (their lifetime is tied to the cache).
"""
function unsafe_free!(ref::DataRef)
    if ref.cached
        # lifetimes of cached references are tied to the cache.
        return
    end
    if ref.freed
        # multiple frees *of the same object* are allowed.
        # we should only ever call `release` once per object, though,
        # as multiple releases of the underlying data is not allowed.
        return
    end
    ref.freed = true
    release(ref.rc)
    return
end

Expand Down
70 changes: 36 additions & 34 deletions src/host/alloc_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ end

mutable struct AllocCache
lock::ReentrantLock
busy::Dict{UInt64, Vector{Any}} # hash(key) => GPUArray[]
free::Dict{UInt64, Vector{Any}}
busy::Dict{UInt64, Vector{DataRef}}
free::Dict{UInt64, Vector{DataRef}}

function AllocCache()
cache = new(
Expand All @@ -24,43 +24,48 @@ end
# Fetch (or lazily create) the vector of references stored under key `uid` in
# the requested pool (`:busy` or `:free`). Callers are expected to hold
# `cache.lock`, as done by `cached_alloc` and `free_busy!`.
function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64)
    pool = getproperty(cache, pool)
    uid_pool = get(pool, uid, nothing)
    if uid_pool === nothing
        uid_pool = pool[uid] = DataRef[]
    end
    return uid_pool
end

"""
    cached_alloc(f, key)

Return a `DataRef` for `key`, re-using a previously freed allocation from the
active `ALLOC_CACHE` when one is available, and calling `f` to allocate a
fresh one otherwise. When no cache is active, simply returns `f()`.
Freshly allocated references are marked `cached`, tying their lifetime to the
cache; all returned references are tracked in the cache's busy pool until
`free_busy!` runs.
"""
function cached_alloc(f, key)
    cache = ALLOC_CACHE[]
    if cache === nothing
        # caching is disabled in this scope: just allocate.
        return f()::DataRef
    end

    ref = nothing
    uid = hash(key)

    # try to re-use a previously freed allocation with the same key.
    Base.@lock cache.lock begin
        free_pool = get_pool!(cache, :free, uid)
        if !isempty(free_pool)
            ref = pop!(free_pool)
        end
    end

    if ref === nothing
        # nothing to re-use: allocate fresh, and hand ownership to the cache.
        ref = f()::DataRef
        ref.cached = true
    end

    # track the reference as in-use until the `@cached` block ends.
    Base.@lock cache.lock begin
        busy_pool = get_pool!(cache, :busy, uid)
        push!(busy_pool, ref)
    end

    return ref
end

function free_busy!(cache::AllocCache)
for uid in cache.busy.keys
busy_pool = get_pool!(cache, :busy, uid)
isempty(busy_pool) && continue
Base.@lock cache.lock begin
for uid in keys(cache.busy)
busy_pool = get_pool!(cache, :busy, uid)
isempty(busy_pool) && continue

Base.@lock cache.lock begin
free_pool = get_pool!(cache, :free, uid)
append!(free_pool, busy_pool)
empty!(busy_pool)
Expand All @@ -71,14 +76,13 @@ end

function unsafe_free!(cache::AllocCache)
Base.@lock cache.lock begin
for (_, pool) in cache.busy
isempty(pool) || error(
"Invalidating allocations cache that's currently in use. " *
"Invalidating inside `@cached` is not allowed."
)
for pool in values(cache.busy)
isempty(pool) || error("Cannot invalidate a cache that's in active use")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check that the pool is empty and that the cache is in use (reading ScopedValue). Otherwise, if an exception happens during @cached it will raise this in the cache finalizer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds questionable. Shouldn't @cached instead have a try/catch (well, the scope-less variety) that wipes the cache in the finally block, making sure the cache is always in a known-good state outside of the @cached block?

end
for (_, pool) in cache.free
map(unsafe_free!, pool)
for pool in values(cache.free), ref in pool
# release the reference
ref.cached = false
unsafe_free!(ref)
end
empty!(cache.free)
end
Expand Down Expand Up @@ -143,13 +147,11 @@ GPUArrays.unsafe_free!(cache)
See [`@uncached`](@ref).
"""
macro cached(cache, expr)
    # build the body and cleanup separately so the cleanup can be placed in a
    # `finally` position: `free_busy!` must run even when `expr` throws, so the
    # cache is left in a known-good state outside of the `@cached` block.
    try_expr = :(@with $(esc(ALLOC_CACHE)) => cache $(esc(expr)))
    fin_expr = :(free_busy!($(esc(cache))))
    return quote
        local cache = $(esc(cache))
        GC.@preserve cache $(Expr(:tryfinally, try_expr, fin_expr))
    end
end

Expand Down
84 changes: 70 additions & 14 deletions test/testsuite/alloc_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,98 @@
if AT <: AbstractGPUArray
cache = GPUArrays.AllocCache()

# first allocation populates the cache
T, dims = Float32, (1, 2, 3)
GPUArrays.@cached cache begin
x1 = AT(zeros(T, dims))
cached1 = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(T) * prod(dims)
@test sizeof(cache) == sizeof(cached1)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test x1 === cache.free[key][1]
@test cache.free[key][1] === GPUArrays.storage(cached1)

# Second allocation hits cache.
# second allocation hits the cache
GPUArrays.@cached cache begin
x2 = AT(zeros(T, dims))
# Does not hit the cache.
GPUArrays.@uncached x_free = AT(zeros(T, dims))
cached2 = AT(zeros(T, dims))

# explicitly uncached ones don't
GPUArrays.@uncached uncached = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(cached2)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test cache.free[key][1] === GPUArrays.storage(cached2)
@test uncached !== cached2

# compatible shapes should also hit the cache
dims = (3, 2, 1)
GPUArrays.@cached cache begin
cached3 = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(T) * prod(dims)
@test sizeof(cache) == sizeof(cached3)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test x2 === cache.free[key][1]
@test x_free !== x2
@test cache.free[key][1] === GPUArrays.storage(cached3)

# Third allocation is of different shape - allocates.
# as should compatible eltypes
T = Int32
GPUArrays.@cached cache begin
cached4 = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(cached4)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test cache.free[key][1] === GPUArrays.storage(cached4)

# different shapes should trigger a new allocation
dims = (2, 2)
GPUArrays.@cached cache begin
x3 = AT(zeros(T, dims))
cached5 = AT(zeros(T, dims))

# we're allowed to early free arrays, which should be a no-op for cached data
GPUArrays.unsafe_free!(cached5)
end
@test sizeof(cache) == sizeof(cached4) + sizeof(cached5)
_keys = collect(keys(cache.free))
key2 = _keys[findfirst(i -> i != key, _keys)]
@test length(cache.free[key]) == 1
@test length(cache.free[key2]) == 1
@test x3 === cache.free[key2][1]
@test cache.free[key2][1] === GPUArrays.storage(cached5)

# we should be able to re-use the early-freed
GPUArrays.@cached cache begin
cached5 = AT(zeros(T, dims))
end

# exceptions shouldn't cause issues
@test_throws "Allowed exception" GPUArrays.@cached cache begin
AT(zeros(T, dims))
error("Allowed exception")
end
# NOTE: this should remain the last test before calling `unsafe_free!` below,
# as it caught an erroneous assertion in the original code.

# Freeing all memory held by cache.
# freeing all memory held by cache should free all allocations
@test !GPUArrays.storage(cached1).freed
@test GPUArrays.storage(cached1).cached
@test !GPUArrays.storage(cached5).freed
@test GPUArrays.storage(cached5).cached
@test !GPUArrays.storage(uncached).freed
@test !GPUArrays.storage(uncached).cached
GPUArrays.unsafe_free!(cache)
@test sizeof(cache) == 0
@test GPUArrays.storage(cached1).freed
@test !GPUArrays.storage(cached1).cached
@test GPUArrays.storage(cached5).freed
@test !GPUArrays.storage(cached5).cached
@test !GPUArrays.storage(uncached).freed
## test that the underlying data was freed as well
@test GPUArrays.storage(cached1).rc.count[] == 0
@test GPUArrays.storage(cached5).rc.count[] == 0
@test GPUArrays.storage(uncached).rc.count[] == 1
end
end
Loading