Complex numbers #159
I am not sure conjugation and holomorphicity are related. (cc @MikeInnes)
Hmm okay, then I'm slightly confused. The following:

```julia
julia> using Zygote

julia> x = 5.0 + 4.0im
5.0 + 4.0im

julia> y = 4.0 + 3.0im
4.0 + 3.0im

julia> z, back = Zygote.pullback(+, x, y)
(9.0 + 7.0im, Zygote.var"#36#37"{Zygote.var"#1587#back#606"{Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}}}(Zygote.var"#1587#back#606"{Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}}(Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}(5.0 + 4.0im, 4.0 + 3.0im))))

julia> back(1.0 + 1.0im)
(1.0 + 1.0im, 1.0 + 1.0im)
```

suggests that, at least here, Zygote returns the cotangent itself, as opposed to the conjugate of the cotangent. Am I missing something, @oxinabox?

Edit: it might be worth writing something up about the basics of AD with complex numbers, if no one has already.
I am not sure about anything to do with complex-number AD, so a document would be good.
I'd love to see that, and also how it connects to AD through functions with real inputs and outputs but complex intermediates.
The analytical complex derivative is defined as $f'(z) = \lim_{h \to 0} \frac{f(z + h) - f(z)}{h}$, where $h \to 0$ through complex values.
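For a holomorphic scalar function, the two candidate pullback conventions multiply the cotangent either by $f'(z)$ or by its conjugate. A minimal plain-Julia sketch (`pullback_transpose` and `pullback_adjoint` are hypothetical names, not Zygote or ChainRules API) of why the `+` example above cannot distinguish them: both partial derivatives of `x + y` are 1, which is real, so both conventions return the cotangent unchanged. A function with a non-real derivative, such as `*`, does tell them apart.

```julia
# Hypothetical helpers: the two candidate pullback conventions for a
# holomorphic scalar function whose derivative at the primal point is `df`.
pullback_transpose(df, Δ) = df * Δ        # multiply by f'(z)
pullback_adjoint(df, Δ)   = conj(df) * Δ  # multiply by conj(f'(z))

Δ = 1.0 + 1.0im

# z = x + y: ∂z/∂x = ∂z/∂y = 1 is real, so both conventions return Δ
@show pullback_transpose(1.0 + 0.0im, Δ)  # 1.0 + 1.0im
@show pullback_adjoint(1.0 + 0.0im, Δ)    # 1.0 + 1.0im

# z = x * y with y = 4 + 3im: ∂z/∂x = y is not real, so the conventions differ
y = 4.0 + 3.0im
@show pullback_transpose(y, Δ)  # 1.0 + 7.0im
@show pullback_adjoint(y, Δ)    # 7.0 + 1.0im
```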
🤦 you're right of course. I clearly need to revise some of the basics lol
Definitely interested. I've dropped you a message on Slack.
Does ChainRules agree with FiniteDifferences? For example, is this the expected behavior (on FiniteDifferences v0.10.0)?

```julia
julia> using FiniteDifferences

julia> j′vp(central_fdm(5, 1), sin, 1.0 + 0.0im, 2.0 + 3.0im)
(-4.18962569096891 + 9.10922789375644im,)

julia> j′vp(central_fdm(5, 1), sqrt, 1.0 + 0.0im, 2.0 + 3.0im)
(0.23216272632545484 + 0.1242497204556949im,)

julia> using ChainRules

julia> ChainRules.rrule(sin, 2.0+3.0im)[2](1.0+0.0im)
(Zero(), -4.189625690968807 - 9.109227893755337im)

julia> ChainRules.rrule(sqrt, 2.0+3.0im)[2](1.0+0.0im)
(Zero(), 0.2321627263254075 - 0.12424972045565168im)
```
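A quick plain-Julia check (no AD packages; the numbers are copied from the output above) suggests the two results differ by exactly a complex conjugation: the `rrule` values equal $f'(z)\,\Delta$, while the `j′vp` values equal $\overline{f'(z)}\,\Delta$, using $\sin'(z) = \cos(z)$ and $\sqrt{z}\,' = 1/(2\sqrt{z})$.

```julia
# Values copied from the REPL output above.
chainrules_sin  = -4.189625690968807 - 9.109227893755337im
fd_sin          = -4.18962569096891 + 9.10922789375644im
chainrules_sqrt = 0.2321627263254075 - 0.12424972045565168im
fd_sqrt         = 0.23216272632545484 + 0.1242497204556949im

z = 2.0 + 3.0im
Δ = 1.0 + 0.0im
dsin  = cos(z)              # sin'(z)
dsqrt = 1 / (2 * sqrt(z))   # sqrt'(z)

# rrule matches multiplication by f'(z) ...
@show chainrules_sin ≈ dsin * Δ
@show chainrules_sqrt ≈ dsqrt * Δ
# ... while j′vp matches multiplication by conj(f'(z))
@show fd_sin ≈ conj(dsin) * Δ
@show fd_sqrt ≈ conj(dsqrt) * Δ
```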
No, that is not the expected behaviour. At least, not the behaviour I expected.
For the complex rules I've checked, Zygote's adjoints generally seem to agree with FiniteDifferences':

```julia
julia> A, B, C̄ = randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3);

julia> all(Zygote.pullback(*, A, B)[2](C̄) .≈ j′vp(central_fdm(5, 1), *, C̄, A, B))
true

julia> all(Zygote.pullback(inv, A)[2](C̄) .≈ j′vp(central_fdm(5, 1), inv, C̄, A))
true
```
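One way to see why `*` and `inv` agree with `j′vp`: the standard adjoint-convention pullbacks are $\bar{A} = \bar{C}B^{H}$, $\bar{B} = A^{H}\bar{C}$ for $C = AB$, and $\bar{A} = -X^{H}\bar{X}X^{H}$ for $X = A^{-1}$. These satisfy the defining identity of a real-transpose pullback, $\operatorname{Re}\langle \bar{Y}, dY\rangle = \operatorname{Re}\langle \bar{X}, dX\rangle$. I assume Zygote's adjoints are equivalent to these, which is consistent with the `true` results above. The sketch checks the identity with random matrices in plain Julia; `mul_pullback` and `inv_pullback` are hypothetical helpers, not Zygote's source.

```julia
using LinearAlgebra

# Hypothetical helpers: adjoint-convention pullbacks for C = A * B and X = inv(A).
mul_pullback(A, B, Cbar) = (Cbar * B', A' * Cbar)       # returns (Abar, Bbar)
inv_pullback(A, Xbar) = (X = inv(A); -X' * Xbar * X')   # returns Abar

# `dot` conjugates its first argument, so real(dot(P, Q)) is the real inner
# product of P and Q viewed as real vectors. A pullback is the transpose of the
# real-linear derivative exactly when real(dot(Ybar, dY)) == real(dot(Xbar, dX))
# for every input tangent dX and the output tangent dY it produces.
A, B, Cbar = randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3)
dA, dB = randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3)

# C = A * B: pushforward dC = dA * B + A * dB
Abar, Bbar = mul_pullback(A, B, Cbar)
dC = dA * B + A * dB
lhs_mul, rhs_mul = real(dot(Cbar, dC)), real(dot(Abar, dA)) + real(dot(Bbar, dB))

# X = inv(A): pushforward dX = -X * dA * X, reusing Cbar as the cotangent of X
X = inv(A)
dX = -X * dA * X
lhs_inv, rhs_inv = real(dot(Cbar, dX)), real(dot(inv_pullback(A, Cbar), dA))

@show isapprox(lhs_mul, rhs_mul; rtol = 1e-8, atol = 1e-8)
@show isapprox(lhs_inv, rhs_inv; rtol = 1e-8, atol = 1e-8)
```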
@willtebbutt, can you run this whole thing down with people and come back with an explanation of exactly what we do differently from Zygote, and whether we are just wrong?
I'd say this is expected. The …
Note that FiniteDifferences.jl also has `jvp`:

```julia
julia> jvp(central_fdm(5, 1), sin, (2.0 + 3.0im, 1.0 + 0.0im))
-4.18962569096891 - 9.109227893754833im
```

Though for some reason it has a different interface than …
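Why do `jvp` and `j′vp` give conjugate answers for `sin`? Viewing $f$ as a map $\mathbb{R}^2 \to \mathbb{R}^2$, a holomorphic function with $f'(z) = a + ib$ has real Jacobian $J = \begin{pmatrix} a & -b \\ b & a \end{pmatrix}$ (the Cauchy-Riemann equations). Applying $J$ is multiplication by $f'(z)$; applying the real transpose $J^{\top}$ is multiplication by $\overline{f'(z)}$. Assuming FiniteDifferences' `jvp` and `j′vp` compute $Jv$ and $J^{\top}v$ on the (re, im) representation, a plain-Julia sketch (`realjac` and `tovec` are hypothetical helpers) reproduces both numbers:

```julia
# Hypothetical helpers: the real 2×2 matrix of multiplication by a complex c,
# and the (re, im) vector of a complex number.
realjac(c::Complex) = [real(c) -imag(c); imag(c) real(c)]
tovec(z::Complex) = [real(z), imag(z)]

df = cos(2.0 + 3.0im)   # f = sin, so f'(z) = cos(z)
v = 1.0 + 0.0im         # tangent / cotangent used in the examples above

J = realjac(df)
jvp_c  = complex((J * tovec(v))...)    # J  * v, mapped back to a complex number
jtvp_c = complex((J' * tovec(v))...)   # Jᵀ * v, mapped back to a complex number

@show jvp_c    # ≈ -4.1896 - 9.1092im, matching jvp above
@show jtvp_c   # ≈ -4.1896 + 9.1092im, matching j′vp above
@show jvp_c ≈ df * v, jtvp_c ≈ conj(df) * v
```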
Should FiniteDifferences be using the transpose, though? We should be able to work this out by actually using the result to perform gradient descent and checking that our loss decreases, right?
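The gradient-descent check can be sketched on a toy problem (chosen here for illustration, not taken from the thread): minimise $L(z) = |f(z)|^2$ with $f(z) = iz$, so $f'(z) = i$. For a real loss, the gradient in complex form is $\partial L/\partial x + i\,\partial L/\partial y$. Pulling the cotangent of $w = f(z)$ back with $\overline{f'(z)}$ recovers that gradient, so a small step reduces the loss; pulling back with $f'(z)$ here points the opposite way.

```julia
# Toy problem (chosen for illustration): minimise L(z) = abs2(f(z)) with f(z) = im*z.
f(z) = 1.0im * z
df = 1.0im                  # f'(z), constant for this f
loss(z) = abs2(f(z))

z0 = 1.0 + 0.0im
# Cotangent of w = f(z0) for L = abs2(w), written as ∂L/∂re(w) + im*∂L/∂im(w) = 2w
wbar = 2 * f(z0)

grad_adjoint   = conj(df) * wbar   # pull back with conj(f'(z))
grad_transpose = df * wbar         # pull back with f'(z)

η = 0.1
@show loss(z0)                       # 1.0
@show loss(z0 - η * grad_adjoint)    # ≈ 0.64: the loss decreases
@show loss(z0 - η * grad_transpose)  # ≈ 1.44: the loss increases
```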
Depends on what your goal is, I guess. I'd say the point of providing …
It's also worth noting that:

```julia
julia> jacobian(central_fdm(5, 1), sin, 2.0 + 3.0im)[1]
2×2 Array{Float64,2}:
 -4.18963   9.10923
 -9.10923  -4.18963
```
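The 2×2 matrix above has the Cauchy-Riemann form $\begin{pmatrix} a & -b \\ b & a \end{pmatrix}$ with $a + ib = \cos(2 + 3i)$, and its first row, $(a, -b)$, is exactly the earlier `j′vp` result, i.e. $J^{\top}(1, 0)$. A hand-rolled central-difference sketch in plain Julia (`fd_jacobian` is a hypothetical helper, not the FiniteDifferences.jl API) reproduces this:

```julia
# Hypothetical helper (not the FiniteDifferences.jl API): central-difference
# Jacobian of a complex function f, viewed as a map (re, im) -> (re, im).
function fd_jacobian(f, z::Complex; h = 1e-6)
    g(x) = (w = f(complex(x[1], x[2])); [real(w), imag(w)])
    x0 = [real(z), imag(z)]
    J = zeros(2, 2)
    for j in 1:2
        e = zeros(2)
        e[j] = h
        J[:, j] = (g(x0 + e) - g(x0 - e)) / (2h)   # column j: ∂(re, im)/∂x_j
    end
    return J
end

z = 2.0 + 3.0im
J = fd_jacobian(sin, z)
a, b = real(cos(z)), imag(cos(z))   # sin'(z) = cos(z) = a + b*im

@show J                                       # ≈ [a -b; b a]
@show isapprox(J, [a -b; b a]; atol = 1e-6)   # Cauchy-Riemann structure
@show complex((J' * [1.0, 0.0])...)           # ≈ conj(cos(z)), the j′vp result above
```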
This is very much on my todo list, but is unlikely to get done for a few weeks because of the NeurIPS deadline.
I ran a bit of this down:
In the longer term, we may want to take complex differentiation seriously, but in the short term we just need a consistent convention, ideally one we want to keep. Do we transpose or adjoint?
This was opened as a response to the discovery in this Zygote PR that our handling of complex functions isn't currently Zygote-compatible when they're not analytic.
The issue appears to stem from the need to conjugate the input/output to `rrule`s when they are used in Zygote. This is a direct consequence of ChainRules currently assuming that all functions for which rules are implemented are holomorphic: in that case, simply conjugating the input and output of the rule makes total sense and does the correct thing.

What needs to happen is:
- a `@holomorphic` or `@analytic` macro for handling scalars involving complex numbers that does the conjugating for you
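The conjugation such a macro would perform can be sketched in plain Julia. This is only an illustration of the idea, assuming the rule is holomorphic; it is not ChainRules' code, and `transpose_pullback` / `adjoint_pullback` are hypothetical names. Conjugating the cotangent on the way in and the result on the way out turns multiplication by $f'(z)$ into multiplication by $\overline{f'(z)}$, since $\overline{f'(z)\,\overline{\Delta}} = \overline{f'(z)}\,\Delta$.

```julia
# `transpose_pullback` stands in for a rule written in the current convention
# (multiply by f'(z)); both helper names are hypothetical.
transpose_pullback(df) = Δ -> df * Δ

# Conjugating the cotangent going in and the result coming out turns
# multiplication by f'(z) into multiplication by conj(f'(z)).
adjoint_pullback(df) = Δ -> conj(transpose_pullback(df)(conj(Δ)))

df = cos(2.0 + 3.0im)   # sin'(z) at z = 2 + 3im
Δ = 1.0 + 0.0im
@show transpose_pullback(df)(Δ)   # ≈ -4.1896 - 9.1092im, like the rrule output
@show adjoint_pullback(df)(Δ)     # ≈ -4.1896 + 9.1092im, like the j′vp output
```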