Incompatibility with Zygote #268
Comments
NB: The other issue I mention in the example code comment above seems to be related with |
This is a leftover from pre v0.7 Julia broadcast days: ArrayFire is translating a broadcast into simple function calls, so |
Hello @GAIKA I am using Julia 1.5.2. Best regards
The bug is here: https://github.com/JuliaGPU/ArrayFire.jl/blob/master/src/array.jl#L217-L233 If you can fix it so that broadcast goes directly into C code, then your Zygote issue (and a whole class of similar issues) will be gone.
The issue is that I don't actually want it to go straight to C in general (maybe for I am not sure what the best solution is, but it seems to me that there are two very different cases that need to be handled separately: 1) Broadcasts that are done within the GPU with an internal ArrayFire function (such as |
Dear ArrayFire developers,
it seems that ArrayFire.jl has a small compatibility issue with Zygote.
A test example follows in [1], but the core of the issue is that Zygote implements a function `Zygote.accum`, which just sums up gradients, and for `AbstractArray`s it is defined as shown in [2]. It basically uses broadcasting to call itself, assuming that the broadcast would call the non-`AbstractArray`-typed version of itself. Unfortunately, the `ArrayFire` broadcasting calls the same function with the arguments still typed as `AbstractArray`, causing an endless loop.

The solution could be a simple override of this function for `AFArray`. With this override, it all works. I am not sure if other overrides are necessary in more general cases, though. Although the Zygote developers could be summoned here, this would create a dependency between Zygote and ArrayFire, which is not really necessary. I am not sure that there is a cleaner way of solving the issue.
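The mechanism described above can be reproduced without a GPU. The following is only a minimal sketch under stated assumptions: `ToyArray` and `ToyStyle` are hypothetical stand-ins for `AFArray` and its broadcast handling, the local `accum` merely follows the shape described in this issue rather than being copied from Zygote, and the `CALLS` guard exists solely so that the demo stops instead of recursing forever.

```julia
# Hypothetical stand-in for AFArray; wraps a plain Array.
struct ToyArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyArray) = size(a.data)
Base.getindex(a::ToyArray, i::Int...) = a.data[i...]
Base.:+(a::ToyArray, b::ToyArray) = ToyArray(a.data + b.data)

# Assumption: mimic the behaviour reported in the comments, where a broadcast
# `f.(args...)` over the array type becomes the plain call `f(args...)`.
struct ToyStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:ToyArray}) = ToyStyle()
Base.copy(bc::Broadcast.Broadcasted{ToyStyle}) = bc.f(bc.args...)

# Local `accum` in the shape the issue describes (not Zygote's actual source).
const CALLS = Ref(0)  # demo-only guard against endless recursion
accum(x, y) = x + y
function accum(x::AbstractArray, y::AbstractArray)
    CALLS[] += 1
    CALLS[] > 10 && error("accum keeps re-entering itself with whole arrays")
    return accum.(x, y)  # expects to reach the scalar method per element
end

# Plain Arrays: the broadcast reaches the scalar method, as intended.
CALLS[] = 0
plain = accum([1.0, 2.0], [3.0, 4.0])

# ToyArrays: the "broadcast" calls accum on the whole arrays again.
a, b = ToyArray([1.0, 2.0]), ToyArray([3.0, 4.0])
CALLS[] = 0
looped = try
    accum(a, b)
    false
catch
    true
end

# A more specific method short-circuits the loop.
accum(x::ToyArray, y::ToyArray) = x + y
fixed = accum(a, b)
```

In the real setting, the analogous fix would presumably be a `Zygote.accum` method for two `AFArray` arguments that returns their sum. That is an assumption drawn from the description above, since the original snippet is not preserved here.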
Best regards,
Danilo
[1]
Test example:
The error I get is:
[2]
From Zygote.jl/src/lib/lib.jl: