Type Constraints #232
I would argue that rules like … and … are reasonable definitions, but they correspond to two different conventions.
Hmm, interesting proposal. If we conclude that they're both reasonable definitions, then that's certainly something that we should consider. Another way of looking at this problem (that I forgot to mention before) is to consider what would happen in an AD system where you hadn't defined an `rrule` for the type in question. So the argument goes as follows:
As regards whether 1 holds or not -- you certainly need to have already defined the … You could think of this argument as being one about consistency. It leans on the idea that there is a small set of basic rules, on whose behaviour everyone can agree, and that all other behaviour follows from there.
I very much agree that it should be possible to split ChainRules into a few rules which define the behaviour and many other rules which provide performance improvements but which could be omitted without breaking any code. I doubt that you can do AD in a sensible way without this property, because without it the interface that a downstream programmer works with would depend on which optimisations are currently implemented. I believe the core, API-defining rules are …
If we decide not to go with the … If we tried to follow the diagonal-as-matrix rule, then we would get into trouble defining the …
This is undecidable without knowing which AD you are talking about. Further, for many ADs, providing rules only for …
I guess a more accurate formulation of the point I was trying to make is that we must be able to guarantee that the many rules we might define for a particular type are all compatible with one another, and one way (perhaps the only way?) to do that is to define the core rules and then require that any other rule must have the same signature as if some hypothetical AD tool had created the rule for us. I understand that different concrete AD tools have different capabilities, but I assume that when their capabilities overlap they agree on what should be happening, and so I further assume that an ideal AD tool which can create all rules from the minimal core set is well defined. Also, I am of course not advocating that ChainRules should not provide optimisation rules, just that these rules must be consistent with what we would get if they weren't there.
Fair enough, that is a decent way to put it. One problem is natural differential types. As a rule, an AD system can only produce structural differential types (i.e. …
Have you written this example down anywhere, @oxinabox? It would be good to gather these ideas here. In particular I would really like a crystal-clear argument for why we should prefer to work with natural differentials.
I believe this is a question of convention. We could either require that all types use … Also, note that "differential type" in the above is meant in a duck-typing sense. It is okay for a type to choose …
I keep getting sidetracked on writing this perspective up (it's long), so I'll just give the short version and apologize if it's unclear or inaccurate. I agree that ideally the right thing to do would be to write rules that do exactly what AD would have done without the rule, just more efficiently. Right now I really don't like the idea of only defining rules on functions of concrete types, though. I implement a lot of custom array types, as do others, and I rely on rules defined on abstract arrays. Without them, e.g. using Zygote, I'm saddled with the (current) terrible performance of the …

Treating abstract arrays as embedded in the dense arrays, it's safe to pull back dense array adjoints, as long as 1) the initial adjoint was valid and 2) the … But as @willtebbutt points out, you lose the time complexity of the primal function in the pullback. One way to partially get this back would be to "project" the output of each pullback to the predetermined differential type for its primal with a utility function. All rules with abstract types would do this. So here's an example for `*`:

```julia
using LinearAlgebra
import ChainRulesCore: rrule, NO_FIELDS, @thunk

function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function times_pullback(Ȳ)
        ∂A = @thunk(project_cotangent(*, A, Ȳ * B'))
        ∂B = @thunk(project_cotangent(*, B, A' * Ȳ))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return A * B, times_pullback
end

# default ignores the function argument
project_cotangent(f, x, ∂x) = project_cotangent(x, ∂x)
# default to a no-op
project_cotangent(x, ∂x) = ∂x

# some possible projections
project_cotangent(x::Array, ∂x) = Array(∂x)
project_cotangent(x::Array, ∂x::Array) = convert(typeof(x), ∂x)
project_cotangent(x::Diagonal, ∂x) = Diagonal(∂x)
project_cotangent(x::LowerTriangular, ∂x) = LowerTriangular(∂x)
project_cotangent(x::UpperTriangular, ∂x) = UpperTriangular(∂x)
project_cotangent(x::Adjoint, ∂x::Adjoint) = ∂x
project_cotangent(x::Adjoint, ∂x) = Adjoint(project_cotangent(x.parent, adjoint(∂x)))
project_cotangent(x::Transpose, ∂x::Transpose) = ∂x
project_cotangent(x::Transpose, ∂x) = Transpose(project_cotangent(x.parent, transpose(∂x)))
# Symbol(x.uplo): the uplo field is a Char, but the constructors expect a Symbol
project_cotangent(x::Symmetric, ∂x) = Symmetric(project_cotangent(x.data, symmetrize(∂x)), Symbol(x.uplo))
project_cotangent(x::Hermitian, ∂x) = Hermitian(project_cotangent(x.data, hermitrize(∂x)), Symbol(x.uplo))

symmetrize(x) = (x .+ transpose(x)) ./ 2
hermitrize(x) = (x .+ x') ./ 2
```

I don't think this guarantees the right time complexity, but by preserving type information that can then be used for dispatch internally within this and subsequent pullbacks, it will likely be more efficient than the pullback using dense arrays. The downside of this approach is that if someone implements a new array type but doesn't define a rule for the constructor, then there's a good chance they won't be able to use AD with the array.
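A quick usage sketch (my own, assuming the `project_cotangent` definitions above):

```julia
using LinearAlgebra

X  = Diagonal([1.0, 2.0])   # primal
∂X = [1.0 2.0; 3.0 4.0]     # dense cotangent produced by a generic pullback
project_cotangent(X, ∂X)    # Diagonal([1.0, 4.0]) -- back in the primal's cotangent space
```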
I agree with you that we're going to have to make some compromises in the short to medium term here to not completely break Zygote. There are probably a few options, and I like the one that you suggest, @sethaxen, particularly for matrices that are (often) otherwise dense-with-constraints, e.g. the … One thing that might be worth working out is how to handle custom rules for cases where the asymptotic complexity would otherwise take a hit, e.g. …
At least some of it should be written down in: https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html The best example I have for why one wants to work with natural differentials is `Cholesky`: you want to work with the properties (see ChainRules.jl, src/rulesets/LinearAlgebra/factorization.jl, lines 87 to 100 at d3cd83e).
I don't think you want to deal with the true field `factors`: https://github.com/JuliaLang/julia/blob/110765a87af68120f2f9f4aa0bbc4054db491359/stdlib/LinearAlgebra/src/cholesky.jl#L119- Though maybe I am wrong in this case, since the relationship of … A better example is …
Hmm, well I think I'm probably okay with using differentials that aren't the … I think this is probably the same reason that I think it's probably fine to represent the (co)tangent of a …
Related is JuliaDiff/ChainRulesCore.jl#176, e.g. if …
Ohhh, so potentially it's generally the case that (if we adopt a strict convention) it's okay for a "bigger" type to have a differential represented by a "smaller" type, but not vice versa. So, for example, a `Matrix` primal could have a `Diagonal` differential, but a `Diagonal` primal could not have a dense `Matrix` differential, etc.
That's one idea. The thinking is that the (co)tangent should be (in some sense) embedded in the (co)tangent space of the primal, which could mean it is in the (co)tangent space of a submanifold that also contains the primal, or could mean that it is just in a subset of the (co)tangent space of the same manifold. I do think, though, that this is a trickier rule to enforce than the …
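One way to make the embedding picture precise, at least for `Diagonal` (my sketch, not an agreed formalisation from this thread): treat the diagonal matrices as a submanifold $\mathcal{D}_N \subset \mathbb{R}^{N\times N}$ with inclusion $\iota$, and pull dense cotangents back along $\iota$:

$$(\iota^{*}\Omega)(V) \;=\; \langle \Omega,\, \iota_{*}V \rangle \;=\; \sum_{i} \Omega_{ii} V_{ii},$$

so that, identified with a matrix via the Euclidean inner product, $\iota^{*}\Omega = \operatorname{Diagonal}(\operatorname{diag}(\Omega))$. On this view, a dense `Matrix` cotangent for a `Diagonal` primal carries strictly more numbers than the cotangent space has dimensions.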
Here's another case study relating to this issue: https://discourse.julialang.org/t/43778
I am not sure I understand the purpose of …
Yup, that's exactly it. It's a stop-gap to help out AD systems that don't handle mutation properly and so rely heavily on pullbacks (I'm really just looking at Zygote here, tbh). Edit: of course we would set things up such that you can always optimise stuff still, and we would do so in all cases where we have the time / inclination to sort them out.
Any given pullback runs the risk of the cotangent wandering away from the cotangent space of the primal. In general it will be the same cotangent, just embedded in a larger cotangent space. For example, in … This is problematic for two reasons: 1) the representation in the larger cotangent space will generally be less efficient, with operations having a worse time complexity than the primal; 2) as the program becomes larger and more complex, with more user-defined rules, there's a higher chance that a pullback is encountered that makes assumptions about the cotangent (such as "the cotangent is real") that are violated by this cotangent vector. This can result in the computed gradient just being wrong.
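A concrete toy illustration (my own) of the "wandering" in the `Diagonal` case:

```julia
using LinearAlgebra

X = Diagonal([1.0, 2.0])
Y = [1.0 2.0; 3.0 4.0]
Ȳ = ones(2, 2)   # cotangent of X * Y
∂X = Ȳ * Y'      # a generic pullback returns a dense Matrix,
                 # outside the cotangent space of the Diagonal primal
```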
Why does mutation matter in this context? The …
Isn't this exactly the same as with standard dispatch? For example, when you create a type … As far as I can tell, it's the same here.
Re integers, I guess the possible options are (1) they are categorical, all derivatives … Zygote takes (3) right now. That's also the rule of every blackboard lecture on calculus ever. Option (1) would break, for instance, every example in Zygote's readme, which seems surprising. Under option (2), I don't think a choice here implies rules for complex (or matrix) types. Integers really are different. You can do calculus over R, or over C (or worse), and over R^N etc., but you can't do calculus over the integers. So either you don't at all, (1), or you declare them to be real numbers which happen to take up less blackboard space, (3). I hadn't seen #224, but it doesn't do anything special for integers, right? It won't produce a complex gradient for the power in …
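To make the `^` point concrete, a sketch (mine, not from #224) of why integer exponents are genuinely different:

```julia
# For negative x, x^p is well-defined for integer p, and the derivative w.r.t. x
# stays real; but the derivative w.r.t. the exponent involves log(x), which is
# complex for x < 0, so it only makes sense if p may vary continuously.
x, p = -2.0, 3
dx = p * x^(p - 1)           # 12.0: real and correct for the integer exponent
dp = x^p * log(complex(x))   # complex-valued: meaningless if p must stay an Integer
```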
Zygote's machinery works as intended / custom … The usual rationale around falling back to generic implementations when you've not got a specialised one in Julia does not apply to AD. The reason that it's acceptable to fall back to a slow definition in general Julia code is that you've not written specialised code, so a slow fallback is better than not having anything. Going with the running example of a `Diagonal` matrix: … The same is not true for AD. Suppose you have indeed written … If, on the other hand, one implements an adjoint for … This is analogous to what is happening here with `logdet`.
Apologies, I definitely wasn't sufficiently clear before. In ChainRules we have

```julia
rrule(::typeof(logdet), X::AbstractMatrix) = ...
```

and in Zygote

```julia
@adjoint logdet(C::Cholesky) = ...
```

Meanwhile, PDMats defines

```julia
LinearAlgebra.logdet(a::PDMat) = logdet(a.chol)
```

where `a.chol` is a `Cholesky`. The desired behaviour is that …
Regarding your comment on the integers: I sympathise with your position on the matter, and agree that it would be quite jarring from a user's perspective (and mine, when I'm being a user). It would certainly be possible for a given AD tool to take a different strategy from ChainRules at the user-facing level. For example, the default behaviour might be to convert … I still believe that consistency has to be the goal of a rules system, though.
Maybe @sethaxen can comment on this? I believe he wrote the code and understands the issue best.
To be clear, are you arguing for what I called option (1), or option (2)? And can you phrase exactly what you mean by consistency here? One way to phrase it is "use the tangent space implicitly defined by dual numbers". This has one real dimension in the case of both floats and integers, two real dimensions for `ComplexF64`, and N (not N^2) real dimensions in the case of a `Diagonal` matrix.
Consistency suggests (2), but the implication of not allowing a float-valued tangent for an `Integer` …
I agree with the dimensionality aspect of what you're saying, and I agree with your characterisation of the dimensionality of the various tangent spaces defined by the dual numbers. Although now that you're forcing me to be really specific (thanks for engaging with this, this is a helpful discussion), I think my point about consistency has been conflating two distinct issues: …
We seem to have some kind of consensus around point number 1. Our discussion around point 2 is trickier, though. There are clearly situations in which someone writes … I don't think that we have the ability to distinguish between the two in general other than from context. Seth's example with matrix powers feels like an example where it really matters, but I'm not completely sure.
Powers are tricky, but are the two problems obviously entangled? I wonder whether, in cases where previously Zygote returned a complex gradient for a real power, it shouldn't return something special to say "don't trust this gradient", which was often not needed anyway. Maybe a variant of … Re dual numbers: if … And, a corner case is whether a `SparseMatrix` has dimension N^2, or only as many dimensions as it has stored entries.
Re types in general, should …
Thanks for explaining why methods come with additional headaches in AD as compared to "standard" Julia, @willtebbutt! This issue has been mentioned in a number of discussions in this repo, and I never quite understood what the problem was, but now I do. However, I am not sure I agree with the proposed solution. The fundamental problem is that we have two different axes along which we would like the compiler / AD tool to auto-generate code for us, namely the specialisation and differentiation axes. What you suggest in your post is that specialisation should take precedence over differentiation, but I am not sure that this is the right thing to do in all circumstances. Maybe this situation is the AD equivalent of an ambiguity error, which needs to be manually resolved by a human?
The above issue is yet another example of this issue cropping up in practice. It's an interesting example that feels morally a bit different from the other examples here, in that it's a rule that is doing quite high-level stuff. A couple of other misc. thoughts on this, in no particular order: …

Not sure how helpful either of these are, but maybe food for thought 🤷
We have another contender! FluxML/Zygote.jl#899
I think the "type constraints" label is rolling together several levels of too-broad definitions: …

This last class is the main concern here: …
I don't think I appreciated this before, but the design of ChainRules is such that you can disable a generic `rrule` …
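A sketch of what that opting-out looks like (my example, using the `@opt_out` macro that ChainRulesCore eventually shipped; see the ChainRulesCore docs for the exact form):

```julia
using LinearAlgebra
using ChainRulesCore

# Opt Diagonal out of the generic matrix-multiplication rrule, so that a
# capable AD differentiates the primal code for Diagonal * Diagonal instead.
@opt_out rrule(::typeof(*), ::Diagonal{<:Real}, ::Diagonal{<:Real})
```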
Care also needs to be taken that this mechanism works for Nabla and Yota.
Oh nice, I didn't know Yota was on board -- since dfdx/Yota.jl#85, it seems. And it looks like invenia/Nabla.jl#189 is the Nabla PR. Both look to be about as tricky as Zygote, i.e. they aren't just calling the …
Closed by JuliaDiff/ChainRulesCore.jl#385.
Various discussions have been had in various places about the correct kinds of types to implement `rrule`s for, but we've not discussed this in a central location. This problem probably occurs for some `frule`s, but doesn't seem as prevalent as in the `rrule` case.

## Problem Statement

The general theme of the problem is whether or not to view certain types as being "embedded" inside others, for the purpose of computing derivatives. For example, is a `Diagonal` matrix one that just happens to be diagonal and is equally well thought of as a `Matrix`, or is it really its own thing? Similarly, is an `Integer` just a `Real`, or is it something else?

I will first attempt to demonstrate the implications of each choice with a couple of representative examples.

### `Diagonal` matrix multiplication

Consider the following `rrule` implementation:
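(Sketched here; the discussion below pins down its form via `ΔΩ`, `ΔΩ * Y'`, and `X' * ΔΩ`.)

```julia
using LinearAlgebra
import ChainRulesCore: rrule, NO_FIELDS

function rrule(::typeof(*), X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real})
    # Pull the output cotangent ΔΩ back to cotangents for X and Y.
    times_pullback(ΔΩ) = (NO_FIELDS, ΔΩ * Y', X' * ΔΩ)
    return X * Y, times_pullback
end
```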
If `X` and `Y` are `Matrix{Float64}`s, for example, then this is a perfectly reasonable implementation -- `ΔΩ` should also be a `Matrix{Float64}` if whoever is calling this rule is calling it correctly.

Things break down if `X` is a `Diagonal{Float64}`. The forwards pass is completely fine, as is the computation of the cotangent for `Y`, `X' * ΔΩ`. However, the complexity of the cotangent representation / computation for `X` is now very concerning -- `ΔΩ * Y'` produces a `Matrix{Float64}`. Such a matrix is specified by `O(N^2)` numbers rather than the `O(N)` required for `X`, and requires `O(N^3)` time to compute, as opposed to the forwards-pass complexity of `O(N^2)`. This breaks the promise that the forwards- and reverse-pass time- and memory-complexity of reverse-mode AD should be the same, in essence rendering the structure in a `Diagonal` matrix useless if this is the only rule for multiplication of matrices in an AD system.

Moreover, what does it mean to consider a non-zero "gradient" w.r.t. the off-diagonal elements of a `Diagonal` matrix? If you take the view that it's just another `Matrix`, then there's no issue. The other point of view is that there's no meaningful way to define a non-zero gradient w.r.t. the off-diagonal elements of a `Diagonal` matrix without considering matrices outside of the space of `Diagonal` matrices -- intuitively, if you "perturb" an off-diagonal element, you no longer have a `Diagonal` matrix. Consequently, a `Matrix` isn't an appropriate type to represent the gradient of a `Diagonal`. If someone has a way to properly formalise this argument, please consider providing it.

It seems that the first view necessitates giving up on the complexity guarantees that reverse-mode AD provides, while the second view necessitates giving up on implementing `rrule`s for abstract types (roughly speaking). The former is (in my opinion) a complete show-stopper, while the latter is something we can in principle live with.
Of course, you could add a specialised implementation for `Diagonal` matrices (see the sketch below). However, I would suggest that you ponder your favourite structured matrix type and try to figure out whether it has similar issues. Most of the structured matrix types that I have encountered suffer from precisely this issue with many operations defined on them -- only those that are "dense" in the same way that a `Matrix` is do not. Consequently, it is not the case that we'll eventually reach a point where we've implemented enough specialised rules -- people will keep creating new subtypes of `AbstractMatrix` and we'll be stuck in a cycle of forever adding new rules. This seems sub-optimal given that a reasonable AD ought to be able to derive them for us. Moreover, whenever someone who isn't overly familiar with AD implements a new `AbstractMatrix`, they would need to implement a host of new `rrule`s, which also seems like a show-stopper.
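For concreteness, a sketch (mine, not from the original text) of such a specialised method, which keeps the cotangent for `X` in `Diagonal` form:

```julia
using LinearAlgebra
import ChainRulesCore: rrule, NO_FIELDS, @thunk

# Only the diagonal of ΔΩ * Y' is needed: diag(ΔΩ * Y')[i] = Σⱼ ΔΩ[i,j] * Y[i,j].
function rrule(::typeof(*), X::Diagonal{<:Real}, Y::AbstractMatrix{<:Real})
    function times_pullback(ΔΩ)
        ∂X = @thunk(Diagonal(vec(sum(ΔΩ .* Y, dims=2))))  # O(N^2) time, O(N) storage
        ∂Y = @thunk(X' * ΔΩ)
        return (NO_FIELDS, ∂X, ∂Y)
    end
    return X * Y, times_pullback
end
```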
### `Number` multiplication

Now consider implementing an `rrule` for `*` between two numbers. Clearly a rule along the following lines (sketched below) is a correctly implemented rule:
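```julia
import ChainRulesCore: rrule, NO_FIELDS

# A sketch of the scalar rule; the original snippet restricted both
# arguments to Float64, as the following paragraph assumes.
function rrule(::typeof(*), x::Float64, y::Float64)
    times_pullback(ΔΩ) = (NO_FIELDS, ΔΩ * y, ΔΩ * x)
    return x * y, times_pullback
end
```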
`Float64` is concrete, so there's no chance that someone will subtype it and require a different implementation for their subtype. In this sense, we can guarantee that this rule is correct for any of the inputs that it admits (up to finite-precision arithmetic issues).

What would happen if you implemented this rule instead for `Real`s? Suppose someone provided an `Integer` argument for `y`; then its cotangent will probably be a `Float64`. While this doesn't present the same complexity issues as the `Diagonal` example above, treating the `Integer`s as being embedded in the `Real`s can cause some headaches, such as the ones that @sethaxen addressed in #224 -- where it becomes very important to distinguish between `Integer` and `Real` exponents for the sake of performance and correctness. Since it is presumably not acceptable to sometimes treat the `Integer`s as special cases of the `Real`s and sometimes not, it follows that `*` should not be implemented between `Real`s, but between `AbstractFloat`s, if we're actually trying to be consistent.

Will this cause issues for users? Seth pointed out that the only situation in which this is likely to be problematic is the case in which an `Integer` argument is provided to an AD tool. This doesn't seem like a show-stopper.

## What Gives?
The issue seems to stem from implementing rules for types that you don't know about. For example, you can't know whether the `*` implementation above is suitable for all concrete matrices that subtype `AbstractMatrix`, even if they otherwise seem like perfectly reasonable matrix types.

How does this square with the typical uses of multiple dispatch within Julia? One common line of thought is roughly "multiple dispatch seems to work just fine in the rest of Julia, so why can't we just implement things as they're needed in this case?". The answer seems to be that … `AbstractMatrix`s will generally give the correct answer. This simply doesn't hold if you're inclined to take the view that a `Matrix` isn't a suitable type to represent the gradient w.r.t. a `Diagonal`, and …

## What To Do?
The obvious thing to do is to advise that rules are only implemented when you know they'll be correct, which means you have to restrict yourself to implementing rules for concrete types, and collections thereof. Unfortunately, doing this will almost certainly break e.g. Zygote, because it relies heavily on very broadly-defined rules that in general exhibit all of the problems discussed above. Maybe some careful middle ground needs to be found, or maybe Zygote needs to press forward with its handling of mutation so that it can actually handle the things it needs to in order to deal with such changes.
## Related Work
#81
Writing this issue was motivated by #226, where @nickrobinson251 pointed out that we should have an issue about this, and that we should consider adding something to the docs.
@sethaxen @oxinabox @MasonProtter any thoughts?
@DhairyaLGandhi @CarloLucibello this is very relevant for Zygote, so could do with your input.