
Complex numbers #159

Closed
willtebbutt opened this issue Apr 29, 2020 · 19 comments
Labels
Complex Differentiation Relating to any form of complex differentiation

Comments

@willtebbutt
Member

This was opened as a response to the discovery in this Zygote PR that our handling of complex functions isn't currently Zygote-compatible when they're not analytic.

The issue appears to stem from the need to conjugate the input / output of rrules when used in Zygote. This is a direct consequence of ChainRules currently assuming that all functions for which rules are implemented are holomorphic, since in that case simply conjugating the input and output of the rule makes total sense and does the correct thing.

What needs to happen is:

  • ChainRulesCore to get a @holomorphic or @analytic macro for handling scalars involving complex numbers that does the conjugating for you
  • ChainRules to use this to implement the existing holomorphic functions
  • Careful testing of scalar functions with complex inputs in ChainRules (I'm assuming the lack of this is how stuff got through undetected in the first place)
  • A short PR to Zygote that simply stops conjugating things
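
To make the conjugation step concrete, here is a minimal numeric sketch (in Python rather than Julia, just so it stands alone; sin is used as an example holomorphic scalar function): under the transpose convention the pullback of a holomorphic f is ΔΩ ↦ f'(x)·ΔΩ, under the adjoint convention it is ΔΩ ↦ conj(f'(x))·ΔΩ, and conjugating the cotangent on the way in and out converts one into the other — which is the bookkeeping the proposed macro would do for rule authors.

```python
import cmath

x = 2.0 + 3.0j   # primal input
dy = 1.0 + 1.0j  # incoming cotangent

# Transpose convention: J^T v = f'(x) * dy, with f = sin, f' = cos
pullback_transpose = cmath.cos(x) * dy

# Adjoint convention: J' v = conj(f'(x)) * dy
pullback_adjoint = cmath.cos(x).conjugate() * dy

# Conjugating the cotangent before and after an adjoint-convention pullback
# recovers the transpose-convention result:
converted = (cmath.cos(x).conjugate() * dy.conjugate()).conjugate()
assert abs(converted - pullback_transpose) < 1e-12
```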
@nickrobinson251 nickrobinson251 added the Complex Differentiation Relating to any form of complex differentiation label Apr 29, 2020
@nickrobinson251
Contributor

nickrobinson251 commented Apr 29, 2020

@oxinabox
Member

I am not sure the conjugation and holomorphy are related.
AFAICT Zygote just always stores and works with the conjugate of the complex gradient.
Just at all times, including returning it.
We work with the actual gradient.
At least I think the gradient we have is the actual one, since we agree with finite differencing, whereas Zygote gets the conjugate of the result that FiniteDifferences.jl gives

(cc @MikeInnes)

@willtebbutt
Member Author

willtebbutt commented Apr 30, 2020

Hmm okay, then I'm slightly confused. The following:

julia> x = 5.0 + 4.0im
5.0 + 4.0im

julia> y = 4.0 + 3.0im
4.0 + 3.0im

julia> z, back = Zygote.pullback(+, x, y)
(9.0 + 7.0im, Zygote.var"#36#37"{Zygote.var"#1587#back#606"{Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}}}(Zygote.var"#1587#back#606"{Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}}(Zygote.var"#604#605"{Complex{Float64},Complex{Float64}}(5.0 + 4.0im, 4.0 + 3.0im))))

julia> back(1.0 + 1.0im)
(1.0 + 1.0im, 1.0 + 1.0im)

suggests that Zygote at least returns the cotangent, as opposed to the conjugate of the cotangent. Am I missing something @oxinabox ?

edit: it might be that I should write something up about the basics of AD + complex numbers if no one has already written anything.

@oxinabox
Member

I am not sure about anything to do with complex-number AD, so a document would be good

@sethaxen
Member

I'd love to see that. Also how it connects to AD through functions with real inputs and outputs but complex intermediates.

@simeonschaub
Member

The analytical complex derivative is defined as ∂/∂z = ∂/∂x - im*∂/∂y. What Zygote does is return ∂/∂x + im*∂/∂y, which is equivalent to the conjugate of the analytical derivative if the output is eventually real, which is always the case in Zygote.
@willtebbutt I believe reverse mode always pulls back one-forms, which live on the cotangent space, but we might not always want to just use Adjoint{<:Vector} to represent them, as that doesn't generalize well to functions operating on higher-dimensional arrays, structured derivatives (e.g. Composite), and also complex numbers, if you want to identify them with the linear functionals on R^2N, as the complex inner product on C^N is completely different from the real inner product on R^2N.
If you are interested, we could also discuss this in a call.
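
A standalone numeric illustration of the point above (a Python sketch; the loss |sin(z)|² is just a hypothetical real-valued function with a complex intermediate): for a real-valued loss, ∂/∂x + im*∂/∂y and ∂/∂x - im*∂/∂y are complex conjugates of each other, which is exactly the conjugation discrepancy being discussed.

```python
import cmath

def loss(z):
    # real-valued loss with a complex intermediate
    return abs(cmath.sin(z)) ** 2

z = 2.0 + 3.0j
h = 1e-6

# Central finite differences for the two real partials
dLdx = (loss(z + h) - loss(z - h)) / (2 * h)
dLdy = (loss(z + h * 1j) - loss(z - h * 1j)) / (2 * h)

grad_zygote_style = dLdx + 1j * dLdy    # ∂/∂x + im*∂/∂y, Zygote's convention
grad_analytic_style = dLdx - 1j * dLdy  # ∂/∂x - im*∂/∂y, the analytical convention

# Because dLdx and dLdy are real, the two conventions are conjugates:
assert abs(grad_zygote_style - grad_analytic_style.conjugate()) < 1e-4
```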

@willtebbutt
Member Author

The analytical complex derivative is defined as ∂/∂z = ∂/∂x - im*∂/∂y.

🤦 you're right of course. I clearly need to revise some of the basics lol

If you are interested, we could also discuss this in a call.

Definitely interested. I've dropped you a message on slack.

@sethaxen
Member

sethaxen commented May 8, 2020

At least I think the gradient we have is the actual one, since we agree with finite differencing, whereas Zygote gets the conjugate of the result that FiniteDifferences.jl gives

Does ChainRules agree with FiniteDifferences? e.g. is this the expected behavior (on FiniteDifferences v0.10.0)?

julia> using FiniteDifferences

julia> j′vp(central_fdm(5, 1), sin, 1.0 + 0.0im, 2.0 + 3.0im)
(-4.18962569096891 + 9.10922789375644im,)

julia> j′vp(central_fdm(5, 1), sqrt, 1.0 + 0.0im, 2.0 + 3.0im)
(0.23216272632545484 + 0.1242497204556949im,)

julia> using ChainRules

julia> ChainRules.rrule(sin, 2.0+3.0im)[2](1.0+0.0im)
(Zero(), -4.189625690968807 - 9.109227893755337im)

julia> ChainRules.rrule(sqrt, 2.0+3.0im)[2](1.0+0.0im)
(Zero(), 0.2321627263254075 - 0.12424972045565168im)
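
The two pairs of outputs above are complex conjugates of each other. A quick standalone check (a Python sketch, for portability) against the exact derivative cos(2 + 3im) of sin confirms which is which:

```python
import cmath

d = cmath.cos(2.0 + 3.0j)  # exact derivative of sin at 2 + 3im

# ChainRules' rrule returned d itself (no conjugation):
assert abs(d - (-4.189625690968807 - 9.109227893755337j)) < 1e-9

# FiniteDifferences' j′vp returned conj(d):
assert abs(d.conjugate() - (-4.18962569096891 + 9.10922789375644j)) < 1e-9
```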

@oxinabox
Member

oxinabox commented May 8, 2020

No, that is not the expected behaviour. At least not the behaviour I expected.
Oh dear, clearly I do not know what I am talking about.
I thought I always had to conjugate Zygote's output to agree with FiniteDifferences...
and I always had to conjugate ChainRules to agree with Zygote

@sethaxen
Member

sethaxen commented May 8, 2020

I thought I always had to conjugate Zygote's output to agree with FiniteDifferences...

For the complex rules I've checked, it seems like Zygote's adjoints generally agree with FD's j′vp's. e.g.

julia> A, B, C̄ = randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3), randn(ComplexF64, 3, 3);

julia> all(Zygote.pullback(*, A, B)[2](C̄) .≈ j′vp(central_fdm(5, 1), *, C̄, A, B))
true

julia> all(Zygote.pullback(inv, A)[2](C̄) .≈ j′vp(central_fdm(5, 1), inv, C̄, A))
true

@oxinabox
Member

oxinabox commented May 11, 2020

@willtebbutt can you run this whole thing down with people and come back with an explanation of what it is exactly that we do differently to Zygote and whether we are just wrong?
I won't have time to look at this this week

@MasonProtter
Contributor

MasonProtter commented May 11, 2020

I'd say this is expected. The \prime in j′vp (as well as the word 'adjoint' in the docstring) suggests that j′vp(f, v, x) is adjoint((D(f))(x))*v where I take D(f)(x) to mean the derivative of f evaluated at x.
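
On that reading, for a scalar function the pullback is conj(f'(x))·v, since the adjoint of a 1×1 Jacobian is its complex conjugate. A Python sketch checking this against the sqrt numbers from FiniteDifferences above:

```python
import cmath

x = 2.0 + 3.0j
v = 1.0 + 0.0j

dfdx = 1 / (2 * cmath.sqrt(x))  # derivative of sqrt

# adjoint(D(f)(x)) * v for a scalar is conj(f'(x)) * v,
# which reproduces FiniteDifferences' j′vp output quoted earlier:
jadj_vp = dfdx.conjugate() * v
assert abs(jadj_vp - (0.23216272632545484 + 0.1242497204556949j)) < 1e-9
```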

@MasonProtter
Contributor

MasonProtter commented May 11, 2020

Note that FiniteDifferences.jl also has jvp, which here will agree with ChainRules (and the derivative you learned in Calc 1) rather than Zygote:

julia> jvp(central_fdm(5, 1), sin, (2.0 + 3.0im, 1.0 + 0.0im))
 -4.18962569096891 - 9.109227893754833im

though for some reason it has a different interface than j′vp, requiring tuples of the form (x, v), where v is the value to be dotted into jacobian(f)(x).

@oxinabox
Member

I'd say this is expected. The \prime in j′vp (as well as the word 'adjoint' in the docstring) suggests that j′vp(f, v, x) is adjoint((D(f))(x))*v where I take D(f)(x) to mean the derivative of f evaluated at x.

Should FiniteDifferences be using transpose though?

We should be able to work this out by actually using the result to perform gradient descent and seeing if our loss decreases, right?

@MasonProtter
Contributor

MasonProtter commented May 11, 2020

Should FiniteDifferences be using transpose though?

Depends on what your goal is, I guess. I'd say the point of providing j′vp is that it's the function that agrees with reverse-mode AD. If that's the case, it should not do a plain transpose.

@MasonProtter
Contributor

It's also worth noting that FiniteDifferences.jacobian does the right thing that it seems nobody else does: it computes the actual Jacobian, without assuming sin is holomorphic:

julia> jacobian(central_fdm(5, 1), sin, 2.0 + 3.0im)[1]
 2×2 Array{Float64,2}:
  -4.18963   9.10923
  -9.10923  -4.18963
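
That 2×2 matrix has the structure forced by the Cauchy-Riemann equations: for holomorphic f with f'(z) = a + b·i, the real Jacobian of f viewed as a map on R² is [[a, -b], [b, a]]. A standalone Python check:

```python
import cmath

d = cmath.cos(2.0 + 3.0j)  # derivative of sin at 2 + 3im
a, b = d.real, d.imag
J = [[a, -b], [b, a]]  # real 2x2 Jacobian of sin, as a map on R^2

# Matches the FiniteDifferences.jacobian output quoted above:
assert abs(J[0][0] + 4.18963) < 1e-4 and abs(J[0][1] - 9.10923) < 1e-4
assert abs(J[1][0] + 9.10923) < 1e-4 and abs(J[1][1] + 4.18963) < 1e-4
```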

@willtebbutt
Member Author

@willtebbutt can you run this whole thing down with people and come back with an explanation of what it is exactly that we do differently to Zygote and whether we are just wrong?
I won't have time to look at this this week

This is very much on my todo list, but is unlikely to get done for a few weeks now because of the NeurIPS deadline.

@nickrobinson251
Contributor

Discourse discussion https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules/39317

@sethaxen
Member

sethaxen commented May 20, 2020

I ran a bit of this down:

  • Let's assume j′vp implements the jacobian-adjoint-vector-product, or equivalently the transpose of the vector-conjugate-jacobian-product (because j' v = (vᵀ j^*)ᵀ). The documentation only says it computes an adjoint (https://www.juliadiff.org/FiniteDifferences.jl/latest/pages/api/#FiniteDifferences.j%E2%80%B2vp), but in the code I only see a transpose, no conjugation.
    Edit: Also, it looks like FD changed conventions recently: Conjugation convention changed from v0.9 to v0.10 FiniteDifferences.jl#87

  • We also know that j′vp agrees with Zygote's adjoint rules. This makes sense, since Zygote documents that it computes j' v (https://fluxml.ai/Zygote.jl/latest/adjoints/#Pullbacks-1):

    Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. v'J, rather than the more usual J*v. Transposing v to a row vector and back (v'J)' is equivalent to J'v so our gradient rules actually implement the adjoint of the Jacobian.

  • ChainRules' docs claim, in the same sentence, both j' v and jᵀ v (http://www.juliadiff.org/ChainRulesCore.jl/dev/#The-propagators:-pushforward-and-pullback-1):

    The pushforward is jacobian vector product (jvp), and pullback is jacobian transpose vector product (j'vp).

  • It also claims it uses the adjoint of the Jacobian: http://www.juliadiff.org/ChainRulesCore.jl/dev/#Push-forwards-and-pullbacks-1:

    If you work out the action in a basis of the cotangent space, you see that it acts by the adjoint of the Jacobian.

    pull back by (in coordinates) multiplying with the adjoint of the Jacobian...

  • In the integration of ChainRules with Zygote (Add ChainRules FluxML/Zygote.jl#366), whenever cotangent vectors are passed between the two packages, they are conjugated. This passes Zygote's test suite. Since (j' v^*)^* = jᵀ v, this seems to indicate ChainRules uses the transpose.

  • ChainRulesTestUtils.rrule_test documents that it pulls back "adjoints", and it internally calls FiniteDifferences.j'vp with no conjugation. However, I tested locally that conjugating the vectors before and after pulling back does not cause any of ChainRules' tests to fail, indicating that none of the tests are sensitive to this convention. ChainRulesTestUtils.test_scalar does not use FiniteDifferences.j'vp (it seems to use a jacobian-vector-product), so only matrices are expected to be sensitive. However, not all of the AbstractMatrix rrules test the complex case. I found locally that the rrules for dot and inv follow the j′vp convention, not the jᵀ v one.

In the longer term, we may want to take complex differentiation seriously, but in the short term we just need to have a consistent convention and ideally one we want to keep. Do we transpose or adjoint?
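
The identity used in the analysis above, (j' v^*)^* = jᵀ v, is easy to verify numerically; a small Python sketch with a hypothetical 2×2 complex Jacobian:

```python
# J is an arbitrary small complex "Jacobian", v an arbitrary cotangent vector
J = [[1 + 2j, 3 - 1j], [0 + 1j, 2 + 2j]]
v = [1 - 1j, 2 + 3j]

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

JT = [[J[j][i] for j in range(2)] for i in range(2)]                # transpose
Jadj = [[J[j][i].conjugate() for j in range(2)] for i in range(2)]  # adjoint

# Conjugate the cotangent, pull back with the adjoint, conjugate the result:
lhs = [c.conjugate() for c in matvec(Jadj, [vi.conjugate() for vi in v])]
rhs = matvec(JT, v)  # plain transpose pullback
assert all(abs(a - b) < 1e-12 for a, b in zip(lhs, rhs))
```

This is why conjugating cotangents at the Zygote / ChainRules boundary converts between the two conventions.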
