Support chunked frule #92
It's just like reverse mode. The partials can be anything, like an array. A nice way to handle chunking is to have the partials be a matrix; then a lot of operations will naturally push forward Jv simultaneously, where the v's are the columns of the seed matrix. I don't get what's broadcasting here. I didn't know it was, but if it is automatically broadcasting then it shouldn't: exploiting the linear algebra is crucial to doing this fast, just like reverse mode.
See broadcast: ...
The idea is that people want to concurrently propagate multiple sensitivity values for the same primal value. And if you do this with a basis matrix, you get the Jacobian in that basis. What does this mean in practical terms: if those are just one derivative each (which is the current case we think of), this more or less works in practice right now; once they are collections of derivatives, it gets harder. This is great.
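To make the basis-matrix point concrete, here is a minimal sketch (toy functions only, none of this is the ChainRules API): pushing an identity seed matrix through the pushforward of a linear map recovers the full Jacobian in one matrix-matrix product.

```julia
using LinearAlgebra

# Toy linear map: f(x) = A*x, so the Jacobian of f is A itself.
A = rand(4, 3)
pushforward(dx) = A * dx   # computes J*dx; works for one seed vector or a matrix of seeds

V = Matrix{Float64}(I, 3, 3)   # basis seeds, one per column
JV = pushforward(V)            # all the J*v's in a single matrix-matrix product

@assert JV ≈ A                                             # the Jacobian in the standard basis
@assert all(JV[:, i] ≈ pushforward(V[:, i]) for i in 1:3)  # matches seed-by-seed pushforwards
```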
It seems to me that it should be fairly straightforward to generalise our existing framework to account for this. I completely understand what @ChrisRackauckas is getting at. It would be helpful if you could specify the extent of your interface as it stands.
We'll need to come up with some interface that plays nicely with all of the above, and plays nicely with our existing design.

One other question @ChrisRackauckas @YingboMa @shashi -- assuming that you do indeed represent a chunk of differentials w.r.t. a given primal ...

As with #91, I'm keen to address this quickly so that we can all press forwards.
The main reason to want this is to compute Jacobians, or nontrivial parts of Jacobians. One thing this does is replace a loop of matrix-vector operations with a single matrix-matrix operation. Furthermore, because forward mode has a fused pushforward (#74), chunking also avoids redoing the primal computation once per seed; this matters less with the separate pullback in reverse mode, because there you don't redo the primal computation.
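A minimal sketch of that last point, using a hand-written fused `frule`-style function for `sin` (the signature is illustrative, not the actual ChainRules API): the primal pass happens once and its results are reused for every seed column, whereas looping single-seed calls repeats it per column.

```julia
# Fused pushforward sketch: one primal pass serves every seed column.
fused_frule_sin(x, Ẋ) = (sin.(x), cos.(x) .* Ẋ)

x = rand(3)
Ẋ = rand(3, 5)                 # a chunk of 5 seeds, one per column
y, Ẏ = fused_frule_sin(x, Ẋ)   # sin.(x) and cos.(x) computed once for all seeds

# Naive alternative: 5 separate single-seed calls, hence 5 primal passes.
naive = reduce(hcat, [fused_frule_sin(x, Ẋ[:, i])[2] for i in 1:5])
@assert Ẏ ≈ naive
```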
When chunking, we have to decide what `extern` should do when a `Zero()` sits alongside array-valued partials. We don't force a particular shape on `Zero`, but we could define:

```julia
Base.ndims(::Zero) = 0
Base.ndims(::Type{Zero}) = 0  # the generated function below maps `ndims` over types

# inferable `argmax(map(ndims, partials))`
@generated _argmax_ndims(partials) = :(partials[$(argmax(map(ndims, partials.parameters)))])

ChainRules.extern(::Zero, partials) = zero(_argmax_ndims(partials))
```

which gives

```julia
julia> extern(Zero(), (1, Zero(), [1 2; 3 4], [1, 2]))
2×2 Array{Int64,2}:
 0  0
 0  0
```

so `Zero()` is externed to a zero with the shape of the highest-dimensional partial.
Can we be chunking some sensitivities and not chunking others?
FD2 doesn't do forward-mode AD on multiple arguments. It is possible to do, but be aware that when two chunks meet, the larger one wins.
Actually, I don't think the above ...
Makes sense.
Also makes sense. Could you provide an example of where this behaviour is currently being exploited to accelerate inference? Do we have any ...

Given our intention to drop ..., we could have:

```julia
struct ChunkedZero <: AbstractDifferential
    chunk_size::Int
end
```

Coupled with the primal, I'm pretty sure this provides enough information to know what to do, i.e. you've retained the chunk-size information, which is the only extra bit of information you need AFAICT.

As regards chunking for structured stuff, I don't know if there's really anything that we need to do -- presumably everything just works out recursively...
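A small sketch of that claim (`materialise` here is a hypothetical helper, not part of ChainRules): coupled with the primal, the stored chunk size is enough to expand a `ChunkedZero` into a concrete zero seed matrix on demand.

```julia
# Hypothetical expansion of a ChunkedZero against its primal:
# one zero column per seed in the chunk.
materialise(dz::ChunkedZero, primal::AbstractVector) =
    zeros(eltype(primal), length(primal), dz.chunk_size)

dz = ChunkedZero(4)
materialise(dz, rand(3))   # 3×4 zero matrix: 4 seeds for a length-3 primal
```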
> accelerate inference?

```julia
julia> using LinearAlgebra

julia> x = rand(3);

julia> dx = rand(3, 3);

julia> frule(BLAS.asum, x, Zero(), dx)
(0.7486352425663505, 5.584182515140222)

julia> sum(i -> frule(BLAS.asum, x, Zero(), dx[:, i])[2], 1:3)
5.584182515140222
```

i.e. the existing `frule` for `BLAS.asum` already pushes the whole 3-column seed matrix forward in one call, and the result matches summing three single-seed calls.
Yes, I agree, ...

Would we end up with ...
Sorry, had my probabilistic programming hat on by accident. I meant examples where chunked computations are done at the minute, and are faster than naively iterating over each (e.g. column) vector of differentials. The ...
Sorry @oxinabox, could you expand on this? I'm not quite sure what is meant.
A question is how we support chunked frule in the case of the author only writing a nonchunked frule (see the sketch below). If we don't do that, another option is: we have ...

This seems intense though. This is a motivation for having both fused and nonfused frule.
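A minimal sketch of what such support could look like (a hypothetical helper, not an actual ChainRules API): derive a chunked pushforward from a single-seed `frule`-style function by looping over seed columns. Note the primal is recomputed once per column, which is exactly the waste that motivates a non-fused or natively chunked frule.

```julia
# Fallback: build a chunked pushforward out of a single-seed rule.
# `single_frule(f, x, ẋ)` is assumed to return `(f(x), J*ẋ)`.
function chunked_frule(single_frule, f, x, Ẋ::AbstractMatrix)
    results = [single_frule(f, x, Ẋ[:, i]) for i in 1:size(Ẋ, 2)]  # one primal pass per seed
    y = first(results[1])                  # the (wastefully recomputed) primal
    return y, reduce(hcat, last.(results))
end

# Usage with a toy single-seed rule for x -> A*x:
A = rand(4, 3)
toy_rule(f, x, ẋ) = (f(x), A * ẋ)
y, Ẏ = chunked_frule(toy_rule, x -> A * x, rand(3), rand(3, 5))   # Ẏ is 4×5
```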
I think we just don't want to support nonchunked ...
One challenge is supporting chunking on functions that don't use their inputs, e.g. zero-arg constructors. This is needed to support forward-mode mutation.
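To illustrate the difficulty (the signatures below are invented for this sketch): a rule with inputs can read the chunk width off its seed, but a zero-arg constructor has no seed to read it from, so the width has to arrive some other way, e.g. carried by something like the `ChunkedZero` above.

```julia
# With an input, the chunk width can be read off the seed:
frule_identity(x, Ẋ::AbstractMatrix) = (x, Ẋ)   # width = size(Ẋ, 2)

# A zero-arg constructor like `() -> zeros(3)` has no seed at all, so a
# chunked rule for it must be told the width some other way:
frule_zeros(chunk_width::Int) = (zeros(3), zeros(3, chunk_width))

frule_zeros(4)   # (3-element zero primal, 3×4 zero seed chunk)
```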
Has there been a conclusion on this?

Not really.
Ok. Here are my two cents. From the discussion in JuliaDiff/ChainRules.jl#232, it is becoming increasingly clear to me that it would probably be good if every Julia type specified its corresponding differential type ...

This scheme can then straightforwardly be extended to include chunking: every type should also specify a chunked differential type ...

I further agree with @oxinabox that ...
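A sketch of what such a scheme might look like (trait names here are hypothetical, assuming the proposal above rather than anything existing in ChainRules):

```julia
# Each primal type declares its (non-chunked) differential type...
differential_type(::Type{Float64}) = Float64
differential_type(::Type{Vector{Float64}}) = Vector{Float64}

# ...and a chunked differential type that holds n seeds at once.
chunked_differential_type(::Type{Float64}) = Vector{Float64}          # n scalars
chunked_differential_type(::Type{Vector{Float64}}) = Matrix{Float64}  # n seed columns
```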
From #90 it seems that @YingboMa wants `frule` to be able to be called on a `Vector` of sensitivities for the same primal value, and get a vector of sensitivities back, but without broadcasting? (Presumably because broadcasting would also recompute the forward primal.) I don't understand it properly, so this thread is to get @YingboMa, @shashi, or @ChrisRackauckas to explain that.

This might need a redesign of `frule` again, similar to solving #74. Maybe we went too far there, since broadcasting the pushforward would presumably solve that case.