Incompatibility with Zygote #268

Open
daniloefl opened this issue Oct 17, 2020 · 5 comments


daniloefl commented Oct 17, 2020

Dear ArrayFire developers,

It seems that ArrayFire.jl has a small compatibility issue with Zygote.

A test example is given in [1], but the core of the issue is that Zygote implements a function, Zygote.accum, which simply sums gradients; its method for AbstractArrays is defined as shown in [2]. That method broadcasts accum over its arguments, assuming that each element-wise call will dispatch to the generic (non-AbstractArray) method. Unfortunately, the ArrayFire broadcasting calls the same function with arguments still as AbstractArray, causing an endless loop.

The solution could be a simple override of this function for AFArray:

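using ArrayFire, Zygote

# Give accum a dedicated AFArray method so it no longer broadcasts accum over itself.
# (The `nothing` checks mirror Zygote's generic method; they cannot trigger for AFArray arguments.)
Zygote.accum(x::AFArray, y::AFArray) =
    x === nothing ? y :
    y === nothing ? x :
    x .+ y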

With this override, the example in [1] works. I am not sure whether other overrides are needed in more general cases, though. The Zygote developers could be asked to add this method instead, but that would create a dependency between Zygote and ArrayFire, which is not really necessary. I do not know of a cleaner way to solve the issue.
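The loop can be reproduced without a GPU. The following is a minimal sketch under stated assumptions: `ToyAF` is a hypothetical stand-in for AFArray whose `broadcasted` method forwards `f.(x, y)` straight to `f(x, y)` (assumed to match the spirit of ArrayFire's array.jl, judging by the stack trace), and the local `accum` copies the two relevant methods of Zygote's definition from [2]. The last method plays the role of the override proposed above.

```julia
# Hypothetical stand-in for AFArray: an AbstractArray wrapper around a plain Array.
struct ToyAF{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyAF) = size(a.data)
Base.getindex(a::ToyAF, i::Int...) = a.data[i...]

# Assumed ArrayFire-like eager broadcast: `f.(x, y)` becomes the direct call `f(x, y)`.
Base.Broadcast.broadcasted(f, x::ToyAF, y::ToyAF) = f(x, y)
Base.:+(x::ToyAF, y::ToyAF) = ToyAF(x.data .+ y.data)

# Local copies of the two relevant Zygote.accum methods (see [2]).
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

a = ToyAF([1.0, 2.0])
b = ToyAF([3.0, 4.0])

# accum(a, b) -> accum.(a, b) -> broadcasted(accum, a, b) -> accum(a, b) -> ...
overflowed = try
    accum(a, b)
    false
catch err
    err isa StackOverflowError
end

# The fix: a more specific method that never broadcasts accum over ToyAF.
accum(x::ToyAF, y::ToyAF) = x + y
c = accum(a, b)
println(overflowed, " ", c.data)   # prints "true [4.0, 6.0]"
```

The override works because Julia dispatches to the most specific method, so `accum(::ToyAF, ::ToyAF)` wins over the AbstractArray method and the self-broadcast is never reached.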

Best regards,
Danilo

[1]
Test example:

using ArrayFire
using Flux
using DiffEqFlux
using Zygote

hyper = FastChain(FastDense(1, 10, tanh), FastDense(10, 10, tanh), FastDense(10, 16, tanh))
p = initial_params(hyper)
x = rand(Float32, 1, 100)

# This override is required because of a separate indexing issue in DiffEqFlux, unrelated to this bug.
# Without it, the example crashes due to another incompatibility, which I believe is a DiffEqFlux issue.
DiffEqFlux.applychain(fs::Tuple, x, p) =
    DiffEqFlux.applychain(Base.tail(fs),
                          first(fs)(x, p[1:DiffEqFlux.paramlength(first(fs))]),
                          length(fs) > 1 ? p[(DiffEqFlux.paramlength(first(fs))+1):end] : Tuple{}())

af_p = AFArray(p)

# This works:
hyper(x, af_p)

# This does not:
gs = Flux.gradient(params(af_p)) do
    sum(hyper(x, af_p))
end

The error I get is:

julia> gs = Flux.gradient(params(af_p)) do
                sum(hyper(x, af_p))
                end
ERROR: StackOverflowError:
Stacktrace:
 [1] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:217
 [2] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [3] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:220
 ... (the last 2 lines are repeated 16335 more times)
 [32674] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [32675] applychain at ./REPL[7]:2 [inlined]
 [32676] (::typeof(∂(applychain)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 ... (the last 2 lines are repeated 1 more time)
 [32679] FastChain at /home/daniloefl/.julia/packages/DiffEqFlux/8UHw5/src/fast_layers.jl:21 [inlined]
 [32680] (::typeof(∂(λ)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32681] #5 at ./REPL[16]:2 [inlined]
 [32682] (::typeof(∂(#5)))(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32683] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [32684] gradient(::Function, ::Zygote.Params) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54

[2]
From Zygote.jl/src/lib/lib.jl:

accum() = nothing
accum(x) = x

accum(x, y) =
  x === nothing ? y :
  y === nothing ? x :
  x + y

accum(x, y, zs...) = accum(accum(x, y), zs...)

accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)


@daniloefl
Author

NB: The other issue I mention in the comment of the example code above seems to be related to DiffEqFlux.jl itself, and I have already filed an issue with its developers here:

SciML/DiffEqFlux.jl#436


ghost commented Oct 17, 2020

> Unfortunately, the ArrayFire broadcasting calls the same function with arguments still as AbstractArray, causing an endless loop.

This is a leftover from the pre-0.7 days of Julia broadcasting: ArrayFire translates a broadcast into a plain function call, so exp.(afarray) simply calls exp(afarray), which in turn calls the C function af_exp. The proper fix would be to re-implement how broadcasting is done.
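For comparison, a minimal sketch of the broadcast machinery Julia has used since 0.7, shown with plain Arrays (no ArrayFire involved). A dot call `f.(x, y)` is lowered to `Base.materialize(Base.broadcasted(f, x, y))`: `broadcasted` only builds a lazy `Broadcasted` object, and `materialize` later applies `f` element by element. Overriding `broadcasted` to call `f(x, y)` on the whole arrays skips that element-wise step, which is the pre-0.7 behaviour described above. The helper `f` and the `seen` log are illustrative only.

```julia
using Base.Broadcast: broadcasted, materialize, Broadcasted

x = [1.0, 2.0]
y = [3.0, 4.0]

# Step 1: building the broadcast is lazy; no arithmetic happens yet.
bc = broadcasted(+, x, y)
println(bc isa Broadcasted)   # prints "true"

# Step 2: materialize applies + to each pair of elements.
z = materialize(bc)
println(z)                    # prints "[4.0, 6.0]"

# The dot syntax is shorthand for the two steps above.
println(z == (x .+ y))        # prints "true"

# Illustrative helper that logs the type of its first argument: with the
# default machinery it only ever sees scalars, never whole arrays.
seen = Type[]
f(a, b) = (push!(seen, typeof(a)); a * b)
f.(x, y)
println(all(==(Float64), seen))   # prints "true"
```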

@daniloefl
Author

Hello @GAIKA

I am using Julia 1.5.2.
I don't understand your last comment. Can you tell me which steps you suggest?

Best regards
Danilo


ghost commented Oct 17, 2020

The bug is here: https://github.com/JuliaGPU/ArrayFire.jl/blob/master/src/array.jl#L217-L233

If you can fix it so that broadcasting goes directly into the C code, then your Zygote issue (and a whole class of similar issues) will be gone.

@daniloefl
Author

The issue is that I don't actually want it to go straight to C in general (for exp that may make sense, but not in this case). Zygote.accum is defined in Julia, not in ArrayFire, and I don't see anything wrong with its definition:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L5

I am not sure what the best solution is, but it seems to me that there are two very different cases that need to be handled separately:
1) broadcasts that run on the GPU through an internal ArrayFire function (such as exp.(A)); and
2) broadcasts of functions implemented in Julia, which in general have no counterpart among the functions ArrayFire already implements (such as accum in Zygote).
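A minimal sketch of the two-case split suggested above, assuming a hypothetical `SplitAF` wrapper instead of AFArray and a hypothetical `NativeOp` list standing in for the functions ArrayFire implements in C. Binary broadcasts of listed functions go to a whole-array method (case 1), while any other function falls back to Base's generic element-wise broadcast (case 2), so a Julia-defined function such as accum is applied to elements rather than to the arrays. This only illustrates the idea; it is not how ArrayFire.jl is structured, and a real fix would also need the fallback to produce GPU arrays rather than CPU ones.

```julia
# Hypothetical stand-in for AFArray.
struct SplitAF{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::SplitAF) = size(a.data)
Base.getindex(a::SplitAF, i::Int...) = a.data[i...]

# Case 1 (assumed list): functions with a pretend native whole-array kernel.
const NativeOp = Union{typeof(+), typeof(*)}
const native_calls = Ref(0)   # counts how often the "native" path is taken
function Base.Broadcast.broadcasted(f::NativeOp, x::SplitAF, y::SplitAF)
    native_calls[] += 1
    SplitAF(broadcast(f, x.data, y.data))   # stands in for a single C/GPU call
end
# Case 2: no method for other functions, so Base's lazy element-wise path is used.

# Local copies of the two relevant Zygote.accum methods.
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

a = SplitAF([1.0, 2.0])
b = SplitAF([3.0, 4.0])

s = a .+ b        # case 1: one whole-array call
g = accum(a, b)   # case 2: accum runs on each pair of Float64 elements, no recursion
println(s isa SplitAF, " ", s.data, " ", native_calls[])   # prints "true [4.0, 6.0] 1"
println(g)                                                 # prints "[4.0, 6.0]"
```

Note that in case 2 the result is a plain `Vector{Float64}`, because the generic fallback allocates its output with `similar`; keeping results on the device would need a custom `BroadcastStyle`.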
