
Remove rules for conj, adjoint, and transpose #67

Merged

Conversation

@devmotion (Member) commented Sep 17, 2021

This PR removes the rules for conj, adjoint, and transpose since they cause problems with ReverseDiff (#54 broke some tests for adjoint; see JuliaDiff/ReverseDiff.jl#183 (comment) and the other comments there for details) and the default fallbacks in https://github.com/JuliaLang/julia/blob/c5f348726cebbe55e169d4d62225c2b1e587f497/base/number.jl#L211-L213 should be sufficient (similar to the discussion about identity in #64).

I checked locally that the ReverseDiff issues are fixed by this PR.

Edit: I also added ReverseDiff to the integration tests.
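
For reference, the Base fallbacks linked above are (roughly) these one-liners, so for real-valued dual and tracked numbers the three functions already behave as the identity:

# (paraphrased from base/number.jl)
conj(x::Real) = x
transpose(x::Number) = x
adjoint(x::Number) = conj(x)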

@codecov-commenter commented Sep 17, 2021

Codecov Report

Merging #67 (8dca5df) into master (d79d2d9) will decrease coverage by 0.14%.
The diff coverage is n/a.


@@            Coverage Diff             @@
##           master      #67      +/-   ##
==========================================
- Coverage   93.89%   93.75%   -0.15%     
==========================================
  Files           2        2              
  Lines         131      128       -3     
==========================================
- Hits          123      120       -3     
  Misses          8        8              
Impacted Files | Coverage Δ
src/rules.jl | 99.15% <ø> (-0.03%) ⬇️


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@devmotion (Member, Author)

The ForwardDiff test errors seem unrelated and appear to be the same ones that JuliaDiff/ForwardDiff.jl#544 tries to address.

@devmotion (Member, Author)

ModelingToolkit errors are the same as on the master branch: https://github.com/SciML/ModelingToolkit.jl/runs/3563293020

@yebai left a comment


Thanks @devmotion - looks good to me!

@mcabbott (Member) commented Sep 17, 2021

Is there a short explanation of why these are wrong? Or why they lead ReverseDiff to the wrong answer?

I tried to read the thread but don't follow the details of its internals. There are no complex numbers involved. The rule does not apply to arrays, hence not to Adjoint etc. Why does it stumble on a function acting on real numbers whose derivative is 1? It doesn't in isolation:

julia> ReverseDiff.gradient(x -> transpose(x[1])^2, [3.0])
1-element Vector{Float64}:
 6.0

yet you say "but the [rule] for transpose lead to a TrackedReal with a derivative of zero!". But why?

ForwardDiff used to define rules for these (or at least for conj), which I presume would need to be put back in a coordinated PR if they are removed here.

@devmotion (Member, Author)

The short answer is JuliaDiff/ReverseDiff.jl#183 (comment) 😛 I guess my explanations were not completely clear, though. I did not refer to derivatives of functions that involve transpose(::Real) (or one of the other functions). I meant transpose(::TrackedReal) (and the other functions) which, as shown in the linked comment, are called when you retrieve or set values of a Transpose or Adjoint of tracked numbers. These methods are defined in https://github.com/JuliaDiff/ReverseDiff.jl/blob/01041c8e8237ed42f6414c6fe0f6e6b12162b6ac/src/derivatives/scalars.jl#L7, so these calls end up in https://github.com/JuliaDiff/ReverseDiff.jl/blob/16b35963234c398fc3a1eb42efab8516eac466e1/src/macros.jl#L84, which returns a TrackedReal with derivative 0.
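
As a hedged illustration of that code path (using a toy Noisy element type, not ReverseDiff's actual TrackedReal): indexing an Adjoint-wrapped matrix forwards through adjoint on each element, so whichever adjoint method is defined for the element type is what gets hit.

julia> using LinearAlgebra

julia> struct Noisy <: Real  # toy element type, stand-in for a tracked number
           x::Float64
       end

julia> Base.adjoint(n::Noisy) = (println("adjoint called on element ", n.x); n);

julia> A = adjoint([Noisy(1.0) Noisy(2.0)]);

julia> A[1, 1];
adjoint called on element 1.0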

I don't think this PR is problematic for ForwardDiff: for all these functions you want f(x::Dual) = x, which is exactly what the default definitions in Julia base are. Also the integration tests did not reveal any test failures (apart from the known random issue).
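
As a quick sanity check of the ForwardDiff side (assuming the Base fallback is what Dual hits; 6.0 is the expected derivative of x -> transpose(x)^2 at 3.0 either way):

julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> transpose(x)^2, 3.0)
6.0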

So in my opinion these rules are not helpful but actually cause problems in downstream packages, and hence it would be good to remove them (basically like the already removed rule for identity).

@ChrisRackauckas (Member)

@shashi you really need to prioritize fixing MTK master. This has gone on for way too long.

@mcabbott (Member)

> I wanted to refer to transpose(::TrackedReal)

But how does my example not involve this? Presumably within a gradient call you will not get transpose(::Real) since the gradient is tracked? Like I said, I don't know the internals of the package. But it seems concerning if gradient 1 of a scalar function leads to gradient zero. Why doesn't this occur more widely?

@devmotion (Member, Author)

When you compute ReverseDiff.gradient you have an additional reverse pass that accumulates the partial derivatives correctly: https://github.com/JuliaDiff/ReverseDiff.jl/blob/01041c8e8237ed42f6414c6fe0f6e6b12162b6ac/src/derivatives/scalars.jl#L47

@mcabbott (Member)

This still seems pretty weird to me. Is it clear to you when this is and isn't going to be triggered? Is it clear what properties other functions would have to have to trigger it?

julia> ReverseDiff.gradient(x -> (x' .+ x)[1], [3.0])  # wrong
1-element Vector{Float64}:
 1.0

julia> ReverseDiff.gradient(x -> (x' + x)[1], [3.0])  # maybe no broadcasting?
1-element Vector{Float64}:
 1.0

julia> ReverseDiff.gradient(x -> (x' + x')[1], [3.0])  # fine
1-element Vector{Float64}:
 2.0
 
julia> ReverseDiff.gradient(x -> (x' + x')[1] + x[1], [3.0])  # fine, mixing adjoint & not
1-element Vector{Float64}:
 3.0

julia> ReverseDiff.gradient(x -> (x')[1] + x[1], [3.0])  # fine
1-element Vector{Float64}:
 2.0

@devmotion (Member, Author)

I tried to explain this in the linked issue 🤷‍♂️ The problem occurs when we call increment_deriv! or decrement_deriv! in the reverse pass with an Adjoint or Transpose of TrackedReal: in this case the default indexing methods in Base call adjoint/transpose on the elements of the wrapped array, and hence the accumulated derivatives become incorrect if we don't use the Base fallback (for which adjoint(x::TrackedReal) === x) but instead the DiffRules-based method that returns a new TrackedReal with deriv = 0. In some cases, e.g. the example f(x) = sum(x' * x) from the issue, the gradients are still correct even though Adjoint or Transpose are involved, since these functions take a special path (e.g. ReverseDiff contains special implementations for *(::Adjoint{<:TrackedReal}, ::AbstractMatrix) etc.).

Fixing indexing of Adjoint{<:TrackedReal} etc. as discussed in the issue is not sufficient, though: e.g. matrix division of a TrackedArray still fails. It falls back to a division involving Adjoint{<:TrackedReal}, and I assume at some point it calls adjoint etc., which is still not correct for TrackedReal. Since I think this PR is the more general and cleaner fix and I did not have more time, I did not continue debugging this error in more detail.
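
To make the accumulation failure concrete, here is a minimal toy sketch (hypothetical MyTracked type and a stand-in increment function, not ReverseDiff's actual internals): if adjoint returns a fresh tracked value with its own zero derivative, increments made during the reverse pass land on the copy and never reach the original.

# Hypothetical toy, not ReverseDiff's real TrackedReal or increment_deriv!
mutable struct MyTracked
    value::Float64
    deriv::Float64
end

diffrule_adjoint(x::MyTracked) = MyTracked(x.value, 0.0)  # DiffRules-style: detached copy
base_adjoint(x::MyTracked) = x                            # Base-style fallback: same object

toy_increment!(x::MyTracked, Δ) = (x.deriv += Δ; x)

x = MyTracked(3.0, 0.0)
toy_increment!(diffrule_adjoint(x), 1.0); x.deriv  # 0.0 -- the increment hit the copy
toy_increment!(base_adjoint(x), 1.0);     x.deriv  # 1.0 -- the increment reached x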

@mcabbott (Member)

> and hence the accumulated derivatives become incorrect

Then it seems like the claim is that any array indexing which calls a function, for which a gradient is defined, may lead to problems? I tried a bit to trigger this with MappedArrays.jl but haven't managed to. Is this the zero gradient you refer to?

julia> sqrt(ReverseDiff.TrackedReal(2.0, 3.0))
TrackedReal<3vn>(1.4142135623730951, 0.0, ---, ---)

julia> transpose(ReverseDiff.TrackedReal(2.0, 3.0))
TrackedReal<2P5>(2.0, 0.0, ---, ---)

I don't oppose the quick fix; I'm just slightly disturbed that a rule which is not mathematically wrong can silently break something deep inside how gradients are accumulated. And I wonder where else that can surface.

@devmotion (Member, Author)

I haven't actively looked for similar issues; I can try MappedArrays as well when I'm back at my computer. I'm far from a ReverseDiff expert, but of course there are bugs and issues about incorrect gradients (as in every AD system, I assume), so I wouldn't be surprised if it can be reproduced in a similar setting 🤷‍♂️

I disagree, though: I don't think this is a quick fix. In my opinion it is the correct thing to do, regardless of ReverseDiff. The default definitions in Base already do the correct thing for dual and tracked numbers, so the rules are not needed - and as we see (and already saw with identity) they not only fail to help downstream packages but actually break stuff.

@shashi commented Sep 17, 2021

@ChrisRackauckas I'm fixing it.

@mcabbott (Member)

> base already do the correct thing for dual and tracked numbers, so

This seems to be the question. Is the sole purpose of this package to provide rules for scalar operator-overloading AD? Is a rule for max or ifelse then forbidden because these should work without one?

Or does "Many differentiation methods" in the readme include other things which may never see the Base definitions? Such as symbolic differentiation. Or generating rules for broadcasting.

@devmotion (Member, Author)

I don't know, I based my opinion on the integration tests and the removal of identity 😄 Maybe the rules are useful at some point, maybe it would be useful to be able to distinguish and pick different subsets/types of rules 🤷‍♂️

But I guess this is something that should be discussed in a different issue. Since you said you're not opposed to the PR, I guess you're fine with me going ahead and merging it?

@devmotion (Member, Author)

FYI I checked and with this PR all your examples above return the correct gradient.

@devmotion merged commit 6a001d2 into JuliaDiff:master on Sep 18, 2021