Broadcast does not work with wrappers (Adjoint, Transpose, SubArray) #147
Hi, I am working on a project that aims to explore the possibility of GPU (CUDA in particular) programming with Julia, using libraries including GPUArrays, CuArrays, etc. It will be really helpful if operations like Regarding wrappers that can take GPUArray as a backing array and expose itself as an
julia> using CuArrays
julia> CuArrays.allowscalar(false)
false
julia> xs = cu([1]);
julia> ys = transpose(xs)
1×1 LinearAlgebra.Transpose{Float32,CuArray{Float32,1}}:
1.0
julia> fill!(ys,0)
ERROR: scalar setindex! is disabled
julia> zs = similar(xs);
julia> copyto!(zs,ys)
ERROR: scalar getindex is disabled
julia> collect(ys)
ERROR: scalar getindex is disabled
julia> Base.print_array(Base.stdout,ys)
1.0f0
The reason that
Base.print_array(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) = Base.print_array(io, LinearAlgebra.transpose(collect(x.parent)))
Unfortunately these hacks are not defined for arbitrary wrappers, so e.g. This affects practically any array wrappers with vector functions that access its content, e.g. The problem is, a wrapper struct with However, as I understand, since If we want to tell the compiler that some arrays with type like
Sunny |
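The dispatch problem described above can be reproduced without a GPU. The following is only a minimal sketch: `ToyGPUArray` and `where_runs` are hypothetical names invented for illustration (they are not part of GPUArrays or CuArrays), and a plain `Array` stands in for device memory. It shows a specialised method being bypassed once the array is wrapped, and how special-casing one wrapper leaves the others uncovered.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU array type; a host Array plays the device buffer.
struct ToyGPUArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyGPUArray) = size(a.data)
Base.getindex(a::ToyGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# A generic fallback (think: scalar-indexing loop in Base) and a
# specialised method for the toy GPU type.
where_runs(::AbstractArray) = :generic_fallback
where_runs(::ToyGPUArray)   = :gpu_specialised

xs = ToyGPUArray(Float32[1 2; 3 4])

direct  = where_runs(xs)             # hits the specialised method
wrapped = where_runs(transpose(xs))  # Transpose is not a ToyGPUArray, so the fallback is chosen

# Special-casing a single wrapper fixes that wrapper only.
where_runs(::Transpose{<:Any,<:ToyGPUArray}) = :gpu_specialised

wrapped_after = where_runs(transpose(xs))    # now specialised
viewed        = where_runs(view(xs, 1:1, :)) # SubArray still falls back
```

Every additional wrapper type (and every nesting of wrappers) needs its own method under this approach, which is the scaling problem the issue is about.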
Probably |
Still needs tests, I'll leave this open until that is fixed. |
Matt pointed out that the storage style traits JuliaLang/julia#25558 was one attempt of solving a related issue. |
I was playing with my toy example and seeing how traits address this problem. I figure that the tricky bit is the recursive case. I've adapted my toy example with Using trait-like technique, such as The problem is that we don't have a way to express a union type like EDIT: maybe we can, I'm experimenting with something like this. Hopefully the computation is done at compile time.
@traitdef IsOnGPU{A}
# helpers
const bottoms = (MyStruct, MyStruct2)
const wrappers = (MyWrapper, MyWrapper2)
is_bottom(x::TypeVar) = false
is_bottom(x::Type) = any(t->x<:t,bottoms)
is_wrapper(x::Type) = any(t->x<:t,wrappers)
unwrap_type(x::Type{MyWrapper{T,AT}}) where {T,AT} = x.parameters[2]
unwrap_type(x::Type{MyWrapper2{T,AT}}) where {T,AT} = x.parameters[2]
Base.@pure SimpleTraits.trait(::Type{IsOnGPU{X}}) where {X} = begin
local Y = X
while is_wrapper(Y)
Y = unwrap_type(Y)
end
is_bottom(Y) ? IsOnGPU{X} : Not{IsOnGPU{X}}
end
The problem is non-existent in the separate-wrapper approach (as shown in my previous toy example), The key idea here is to override constructors to propagate the information via constructors, by injecting additional traits implementations on the fly. One can notice that we are special casing all the wrapper types we want to support. To fix(?) this we need to introduce another layer
Toy example:
using SimpleTraits
abstract type MyInterface{T} end # like AbstractArray
abstract type MyGPUStruct{T} <: MyInterface{T} end # like GPUArray
struct MyStruct{T} <: MyGPUStruct{T} # like CuArray
name::String
x::T
end
struct MyStruct2{T} <: MyGPUStruct{T} # like JLArray, uses fall-back implementations
name::String
x::T
end
struct MyWrapper{T,AT <: MyInterface{T}} <: MyInterface{T} # like LinearAlgebra.Transpose
parent::AT
end
struct MyWrapper2{T,AT <: MyInterface{T}} <: MyInterface{T} # any Wrapper that we want it to fall back to base
parent2::AT
end
# methods in Base to create wrappers, like view, transpose and adjoint
wrap1(x::MyInterface) = MyWrapper(x)
wrap2(x::MyInterface) = MyWrapper2(x)
# overridden methods in `Base`
function get_name(::MyInterface)
println("base fallback")
"fallback"
end
# a method that does expect a wrapper type directly
get_name(o::MyStruct) = begin println("MyStruct"); o.name end
get_name(o::MyStruct2) = begin println("MyStruct2"); o.name end
get_name(o::MyWrapper) = begin println("Wrapper"); get_name(o.parent) end
get_name(o::MyWrapper2) = begin println("Wrapper2"); get_name(o.parent2) end
# a method that does not expect a wrapper type directly
foo(o::MyInterface) = println("generalised foo")
# A trait that marks an array as GPU-compatible
@traitdef IsOnGPU{A}
@traitimpl IsOnGPU{MyStruct}
# Need to put the backend information elsewhere
get_backend(::Type{MyStruct}) = MyStruct
get_backend(::Type{MyStruct{T}}) where {T} = MyStruct
@traitimpl IsOnGPU{MyStruct2}
get_backend(::Type{MyStruct2}) = MyStruct2
get_backend(::Type{MyStruct2{T}}) where {T} = MyStruct2
# need this to be able to invoke the overridden constructor
GPUTarget{T} = Union{MyGPUStruct{T},MyWrapper{T}}
# TODO: maybe use macro to help generate these specialised constructors
# I'm using GPUTarget (I think it's called GPUDestArray right now) to specialise
# on potentially GPU-compatible array types (to avoid infinite loop problem),
# but use the IsOnGPU trait to pick out the actually GPU-compatible ones.
# The tricky bit: handling recursive case is problematic because there is no easy way
# to express the union type:
#
# {MyStruct, MyWrapper{MyStruct}, MyWrapper{MyWrapper{MyStruct}}, ...}
# i.e. the least fixed point of `f`, `f_*`, in a poset with bottom element, `MyStruct`, where
#
# f_0 = MyStruct
# f_1 = f (f_0) = Union{MyStruct, MyWrapper{MyStruct}}
# f_2 = f (f_1) = Union{MyStruct, MyWrapper{MyStruct}, MyWrapper{MyWrapper{MyStruct}}}
# ...
# f_* is the type we need.
#
# The idea here is to assert that a type with trait `IsOnGPU` is within this union type,
# and we preserve this property by construction, i.e. overriding constructors to
# make sure only GPU-compatible types have trait `IsOnGPU`. The idea is similar to using
# a separate wrapper `OnGPU`, but without messing up the type hierarchy and
# other methods that we are not interested in overriding.
#
# What I do here is to dynamically inject trait implementation, backend information, as well as
# specialised method declarations e.g.`get_name`, into the dispatch table. So instead of
# storing the lost information in the type signature, we are effectively
# storing the information in the method dispatch table by using traits.
#
# If we want to override e.g. foo(::MyWrapper{T}), but also call it within our
# overriding method e.g. foo(::A) where {A <: MyWrapper; IsOnGPU{A}}, we need to be careful
# to avoid infinite loops. Also, Julia traits without injected methods won't work for
# methods that expect a wrapper type directly; see `get_name` for example.
# Though I don't know how useful it is to override such methods.
# A problem is that we need to make sure the injected method actually gets compiled and becomes
# visible. Try removing `warmup()` below and watch everything explode.
#
# I don't know if this approach is the right way to do it. And even if it does work, how it affects performance
# remains an important question to ask.
# Override constructor to allow propagation of information
@traitfn MyWrapper(x::A) where {T, A<:GPUTarget{T}; IsOnGPU{A}} = begin
println("invoking hijacked wrapper constructor")
y = invoke(MyWrapper,Tuple{MyInterface{T}},x)
BE = get_backend(A)
WA = typeof(y)
# allow nested wrapper to propagate information
# seems hacky to me
@eval @traitimpl IsOnGPU{$WA}
@eval @inline get_backend(::Type{$WA}) = $BE
# dynamically inject specialised methods to
# hijack methods that expect a concrete wrapper directly
# need to be very careful to avoid getting infinite loops
@eval @inline get_name(x::$WA) = get_name_gpu($BE, x)
y
end
# can use traits functions to deal with methods not directly
# expecting a concrete wrapper instance
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; IsOnGPU{A}} = begin
println("GPU specialised foo")
end
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; !IsOnGPU{A}} = begin
invoke(foo,Tuple{MyInterface},x)
end
# fall back to base
@inline get_name_gpu(BE::Type{T}, x::MyStruct) where {T<:MyGPUStruct} = invoke(get_name,Tuple{MyStruct},x)
@inline get_name_gpu(BE::Type{T}, x::MyWrapper) where {T<:MyGPUStruct} = invoke(get_name,Tuple{MyWrapper},x)
get_name_gpu(BE::Type{MyStruct}, x::MyWrapper) = begin println("GPU version"); get_name_gpu(BE, x.parent) end
get_name_gpu(::Type{MyStruct}, x::MyStruct) = get_name(x)
# force compilation
function warmup()
Base.invokelatest(wrap1,wrap1(MyStruct("",1)))
Base.invokelatest(wrap1,wrap1(MyStruct2("",2)))
end
warmup()
function test()
begin
x = MyStruct("x",1) # uses gpu method
y = wrap1(x) # dispatch to gpu method
z = wrap1(y) # dispatch to gpu method
w = wrap2(z) # fall back to base for non-compatible wrappers
v = wrap1(w) # fall back to base again
get_name(x)
get_name(y)
get_name(z)
get_name(w)
get_name(v)
foo(x)
foo(y)
foo(v)
end
begin
x = MyStruct2("x",1) # backend without specialised get_name implementation
y = wrap1(x) # use fall back methods for all below
z = wrap1(y)
w = wrap2(z)
v = wrap1(w)
get_name(x)
get_name(y)
get_name(z)
get_name(w)
get_name(v)
foo(x) # this GPU implementation is in GPUArray, thus is used here as a fall-back
foo(y)
foo(v)
end
end
test() |
Interesting, I'll have to look at your code in detail later. In the meantime I have been experimenting a bit (https://github.com/JuliaGPU/GPUArrays.jl/tree/vc/explore_bc), but I don't see a way for us to solve this in general without improving Base Julia first. |
Yes, I agree that solving this in general for arbitrary wrappers is impossible without modifying the base type hierarchy, since currently there is no consistent way (even for a human) of telling if an arbitrary EDIT: I don't know if it makes sense to look up wrapper types using reflection at module loading time, like
find_type_name(t::DataType) = [t.name]
find_type_name(t::Union) = cat(find_type_name(t.a), find_type_name(t.b); dims = 1)
find_type_name(t::UnionAll) = find_type_name(t.body)
const wrappers = Set{Core.TypeName}()
for m in Base.methods(Base.parent)
push!(wrappers, find_type_name(m.sig.parameters[2])...)
end
EDIT2: Never mind, Julia method type signatures don't contain the return type, so this is not going to help us actually find the type of the parent array. 😞 I don't know if there is a consistent way of finding the parent array type given a wrapper array type. For now I guess we have to special-case common wrapper types and let users add support for user-defined wrappers. |
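For contrast with the type-level lookup above, the backing array of a wrapper *instance* can be reached at run time by following `Base.parent` until it stops changing. This is only a sketch under assumptions: `ToyGPUArray` and `root_storage` are hypothetical names, and it relies on Base's default `parent(A::AbstractArray) = A`, so a wrapper that does not overload `parent` would be treated as the root. It does not solve the dispatch problem, since the result is a value rather than a type.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU array type.
struct ToyGPUArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyGPUArray) = size(a.data)
Base.getindex(a::ToyGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Follow `parent` until it returns its own argument, i.e. until we reach
# an array that is not (known to be) a wrapper.
function root_storage(x::AbstractArray)
    p = parent(x)
    return p === x ? x : root_storage(p)
end

xs     = ToyGPUArray(Float32[1 2; 3 4])
nested = transpose(view(xs, 1:2, 1:1))   # Transpose of a SubArray of xs

root       = root_storage(nested)        # unwraps two layers
plain_root = root_storage([1, 2, 3])     # a plain Array is its own root
```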
More experiments on the toy example with traits. It seems that we might also be able to override methods directly expecting a wrapper type without overriding wrapper constructors, if the wrapper type is parametric and contains its parent array type as one of its parameters, i.e. we can still use it as a dispatch target by fiddling with its parent array type parameter. It is quite important to test the performance loss of calling I think we can go a long way in implementing specific wrapper support, e.g. |
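The idea of dispatching on the parent array type parameter can be sketched with plain dispatch and no trait package. This is an illustration only, not the approach taken in GPUArrays: `ToyGPUArray`, `Location`, `OnGPU`, `OnCPU` and `location` are hypothetical names, and only three Base/LinearAlgebra wrappers are special-cased. Each wrapper method forwards to its parent type parameter, so nested wrappers are resolved recursively.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU array type.
struct ToyGPUArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyGPUArray) = size(a.data)
Base.getindex(a::ToyGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Holy-trait style markers for where the underlying storage lives.
abstract type Location end
struct OnGPU <: Location end
struct OnCPU <: Location end

# Base cases: anything unknown is treated as CPU storage.
location(::Type{<:AbstractArray}) = OnCPU()
location(::Type{<:ToyGPUArray})   = OnGPU()

# Recursive cases: read the parent type out of the wrapper's type parameters.
location(::Type{<:Transpose{<:Any,P}}) where {P}       = location(P)
location(::Type{<:Adjoint{<:Any,P}}) where {P}         = location(P)
location(::Type{<:SubArray{<:Any,<:Any,P}}) where {P}  = location(P)

# Convenience method for instances.
location(x::AbstractArray) = location(typeof(x))

xs         = ToyGPUArray(Float32[1 2; 3 4])
gpu_nested = location(transpose(view(xs, 1:1, :)))  # Transpose{_, SubArray{_, _, ToyGPUArray}}
cpu_nested = location(adjoint([1 2; 3 4]))          # Adjoint of a plain Matrix
```

A method can then branch on `location(x)` instead of on the wrapper type itself. User-defined wrappers would still need their own `location` method, which is the limitation noted earlier in the thread.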
Some more partial code on supporting slightly more generic wrapper types by using reflection and type inference hacks. Warning: very hacky! |
This works nowadays, by using Adapt.jl. |
@MikeInnes @vchuravy You guys have been dabbling with broadcast recently, any idea?
I also want to get this working for subarrays, which require the following additional changes: