Gradient wrt a sparse matrix is mathematically wrong #1507
That's quite a bold heading!

I think you mean the inputs are … This is true, for example:

```julia
julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2)
(-4.810477380965351,)

julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2 + 0im)
(-4.810477380965351 - 4.810477380965351im,)

julia> pi/2 == pi/2+0im
true
```

That is, Zygote (really ChainRules) regards the type of the input as specifying the domain of the function, and hence the appropriate cotangent space in which the gradient lives. The fact that …

Before ChainRules 1.0, Zygote did not do this. It regarded all numbers as living in ℂ, and all matrices as living in ℂ^(N×M). I think this more or less fell out of how it works and Julia's type promotion rules. In the example above, the fact that the function uses complex numbers internally would lead it to tell you to do gradient descent in a complex direction.

ChainRules applies such projections at almost every step. When a real number propagates forward through some complicated code for 10 steps and then gets promoted to complex for the 11th, the reverse pass projects its gradient back to real immediately, so that those 10 reverse steps involve only real numbers again. Similar projections apply to most structured matrix types (such as `Diagonal`), and also to sparse arrays.
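To make the two REPL results above concrete without any AD library, here is a hand-derived sketch (all names are illustrative, not Zygote or ChainRules API). For the real-valued f = real ∘ g with g(z) = exp(z(1+i)) holomorphic, the gradient over ℂ is ∂f/∂x + i∂f/∂y, which equals conj(g′(z)) by the Cauchy-Riemann equations; projecting that cotangent onto ℝ simply takes its real part, which coincides with the ordinary derivative of exp(x)·cos(x).

```julia
# Hand-derived gradients for f(x) = real(g(x)) with g(z) = exp(z * (1 + im)).
# No AD library is involved; all names here are illustrative.
g(z) = exp(z * (1 + im))
dg(z) = (1 + im) * exp(z * (1 + im))   # complex derivative g'(z)
f(x) = real(g(x))

# Domain ℂ: gradient ∂f/∂x + i*∂f/∂y, equal to conj(g'(z)) by Cauchy-Riemann.
complex_gradient(z::Complex) = conj(dg(z))

# Domain ℝ: project the complex cotangent back onto ℝ by taking its real part.
real_gradient(x::Real) = real(conj(dg(x)))

x0 = pi / 2
complex_gradient(x0 + 0im)   # ≈ -4.8105 - 4.8105im
real_gradient(x0)            # ≈ -4.8105

# The projected gradient matches d/dx [exp(x) * cos(x)] and a central difference.
@assert isapprox(real_gradient(x0), exp(x0) * (cos(x0) - sin(x0)); atol = 1e-12)
@assert isapprox(real_gradient(x0), (f(x0 + 1e-6) - f(x0 - 1e-6)) / 2e-6; atol = 1e-6)
```

Both numbers agree with the Zygote output quoted above; the only difference between them is which cotangent space the input's type selects.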
Moving some of my discussion points over from Slack and polishing them a bit.

ChainRules projects the cotangent of each of the function's inputs to match that input's type. This can lose information when a dense (or otherwise structured but not sparse, e.g. a rank-1 matrix) gradient is projected onto a sparse structure. An example can be seen here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/dense.jl#L14. In my opinion, the projection should only be applied to the cotangent of the function's output, so the … would be:

```julia
import ChainRulesCore: rrule, @thunk, ProjectTo, unthunk, NoTangent
import Zygote
using LinearAlgebra: dot

mydot(x, y) = dot(x, y)

function rrule(::typeof(mydot), x::AbstractArray, y::AbstractArray)
    out = dot(x, y)
    project_out = ProjectTo(out)   # project only the output cotangent
    function dot_pullback(Ω̄)
        ΔΩ = project_out(unthunk(Ω̄))
        # the input cotangents are returned unprojected
        x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))
        ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))
        return (NoTangent(), x̄, ȳ)
    end
    return out, dot_pullback
end
```

instead of the current implementation:

```julia
function rrule(::typeof(dot), x::AbstractArray, y::AbstractArray)
    project_x = ProjectTo(x)
    project_y = ProjectTo(y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = @thunk(project_x(reshape(y .* ΔΩ', axes(x))))
        ȳ = @thunk(project_y(reshape(x .* ΔΩ, axes(y))))
        return (NoTangent(), x̄, ȳ)
    end
    return dot(x, y), dot_pullback
end
```

It seems that this was done on purpose to counter another issue, which I am going to call the "too much information" problem. For example, with the suggested new rule:

```julia
julia> Zygote.pullback(real ∘ mydot, [1,2], Complex.([1,2]))[2](1.0)
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])
```

Is this wrong? Some people think so. I don't. The user is clearly mixing complex and real numbers, so it's on them to project the result if they only care about the real part, or to declare explicitly that …

```julia
julia> f(x, y) = real(mydot(Complex.(x), y))
f (generic function with 1 method)

julia> Zygote.pullback(f, [1,2], Complex.([1,2]))[2](1.0)
([1.0, 2.0], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])
```

To me this is more honest to the user's intentions, at the risk of returning too much information, which the user can easily discard by calling the projection function themselves. So, to summarise, I think projecting the cotangents of a function's inputs leads to information loss and potentially incorrect gradients for some applications.
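The information loss described in this comment can be seen in a few lines of plain Julia. This sketch assumes only the SparseArrays standard library: it computes the unprojected cotangent `y .* ΔΩ'` from the `dot` rule for a sparse `x`, then restricts it to the stored indices of `x`, which is what a projection onto the sparsity pattern is assumed to do here (this is an illustration, not the ChainRulesCore `ProjectTo` code).

```julia
using SparseArrays

x = sparsevec([1], [2.0], 3)   # sparse input: only index 1 is stored
y = [1.0, 2.0, 3.0]            # dense input
ΔΩ = 1.0                       # cotangent of the scalar output dot(x, y)

# Unprojected cotangent for x, as in the dot rule: it is dense.
xbar_dense = y .* ΔΩ'

# Restrict it to the stored indices of x (illustrative projection).
idx, _ = findnz(x)
xbar_projected = sparsevec(idx, xbar_dense[idx], length(x))

Vector(xbar_projected)   # [1.0, 0.0, 0.0]: the 2.0 and 3.0 entries are discarded
```

Whether the discarded entries are lost information or correctly lie outside the input's domain is exactly the point of disagreement in this thread.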
In a long string of rules, projecting before every rule or after every rule will be broadly similar. But what you seem to be arguing for is omitting projection at the very first rule. Won't this lead to all kinds of surprises? For example, these two functions implement the same thing in slightly different ways; why should the user care if some library changes from one implementation to the other?

```julia
julia> Zygote.pullback(x -> real(mydot((3+4im) * x, x)), [1.0, 2.0])[2](1.0)
(ComplexF64[6.0 + 4.0im, 12.0 + 8.0im],)

julia> Zygote.pullback(x -> real((3-4im) * mydot(x, x)), [1.0, 2.0])[2](1.0)
([6.0, 12.0],)
```

(Using the lower-level …)

Edit: in fact this example is even stranger, as the rule for …
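As a sanity check on the claim that the two functions above implement the same thing, the following plain-Julia sketch (no Zygote; `h1`, `h2` and `fd_grad` are illustrative names) evaluates both and estimates their gradient by central finite differences. Because `dot` conjugates its first argument, both reduce to 3‖x‖², whose gradient over ℝ² is 6x, matching the projected result `[6.0, 12.0]`.

```julia
using LinearAlgebra

h1(x) = real(dot((3 + 4im) * x, x))   # dot conjugates its first argument
h2(x) = real((3 - 4im) * dot(x, x))

# Central finite differences, one coordinate at a time.
function fd_grad(f, x; h = 1e-6)
    map(eachindex(x)) do k
        e = zeros(length(x))
        e[k] = h
        (f(x + e) - f(x - e)) / (2h)
    end
end

x = [1.0, 2.0]
h1(x) == h2(x) == 3 * sum(abs2, x)   # true: all three equal 15.0
fd_grad(h1, x)                       # ≈ [6.0, 12.0]
fd_grad(h2, x)                       # ≈ [6.0, 12.0]
```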
I think the example above is an argument against projecting at all in the complex case: projecting complex numbers to real numbers loses information, which can sometimes even violate the distributive property. So if there is a single complex number anywhere in the chain of calculations, it should just propagate backward to the gradient unless users explicitly call … My new suggestion is now:

```julia
import ChainRulesCore: rrule, @thunk, unthunk, NoTangent
import Zygote
using LinearAlgebra: dot

mydot2(x, y) = dot(x, y)

# dot rule with no projection at all: neither the output nor the input
# cotangents are projected
function rrule(::typeof(mydot2), x::AbstractArray, y::AbstractArray)
    out = dot(x, y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))
        ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))
        return (NoTangent(), x̄, ȳ)
    end
    return out, dot_pullback
end

mymul(x, y) = x * y

# scalar multiplication rule, also without projection; the partials are
# conjugated, following the ChainRules convention for complex numbers
function rrule(::typeof(mymul), x::Number, y::Number)
    out = mymul(x, y)
    function mul_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = ΔΩ * conj(y)
        ȳ = conj(x) * ΔΩ
        return (NoTangent(), x̄, ȳ)
    end
    return out, mul_pullback
end
```

```julia
julia> Zygote.pullback(x -> real(mymul((3+4im), mydot2(x, x))), [1.0, 2.0])[2](1.0)
(ComplexF64[6.0 + 0.0im, 12.0 + 0.0im],)

julia> Zygote.pullback(x -> real(mymul((3+4im), Complex(mydot2(x, x)))), [1.0, 2.0])[2](1.0)
([6.0, 12.0],)
```

So basically, calling the constructor of a type is the primal of the projection operation; equivalently, the projection is the pullback of the constructor. This generalises nicely to matrices as well, where if I want the gradient of …
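The idea that the projection is the pullback of the constructor can be checked on the simplest case, the inclusion ℝ → ℂ given by `Complex(x)`. Under the real inner product ⟨a, b⟩ = Re(conj(a)·b) on ℂ viewed as ℝ², the adjoint of that inclusion is Re(·), i.e. exactly the projection onto ℝ. The names below are hypothetical, and this is a sketch of the linear-algebra fact, not ChainRules' actual rule for `Complex`.

```julia
# Hypothetical forward map and pullback for the inclusion ℝ → ℂ.
complex_forward(x::Real) = Complex(x)
complex_pullback(Δ::Number) = real(Δ)   # the cotangent of a real input is real

# Real inner product on ℂ viewed as ℝ²: ⟨a, b⟩ = real(conj(a) * b).
inner(a::Number, b::Number) = real(conj(a) * b)

# Adjointness: ⟨Δ, complex_forward(x)⟩ == ⟨complex_pullback(Δ), x⟩.
for (x, Δ) in [(1.5, 2.0 - 3.0im), (-0.25, 0.5 + 4.0im), (3.0, -1.0 + 0.0im)]
    @assert inner(Δ, complex_forward(x)) == inner(complex_pullback(Δ), x)
end
```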
Yes, it loses information. I'm not sure what distributive property you mean. But is passing in complex input, to specify that you consider the domain to be ℂ rather than ℝ, too much to ask? It just seems extremely strange to me to want these two implementations of the same ℝ → ℝ function to behave differently (as they did in earlier Zygote):

```julia
f1(x::Real) = real(exp(x * (1+im)))
f2(x::Real) = exp(x) * cos(x)

using Plots; plot(f1, -2, 2); plot!(f2, -2, 2)
```

It should be mentioned that the other kind of standardisation done by the projection machinery is to map structural cotangents to natural ones where possible. Here, by default …

```julia
Zygote.gradient(x -> angle(x.re * x^2), 1+2im)  # (-0.8 + 0.4im,)
```
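The value quoted for the `angle` example can be cross-checked without Zygote. Near z = x + iy = 1 + 2i we have x > 0, so angle(x·z²) = 2·angle(z), and the gradient ∂f/∂x + i∂f/∂y is (−2y + 2xi)/|z|² = −0.8 + 0.4i: the natural complex cotangent rather than a structural one. The helper `fd_gradient` below is a hypothetical central-difference sketch.

```julia
# Central finite differences for a real-valued function of one complex
# argument, packed as ∂f/∂x + i*∂f/∂y (the natural complex cotangent).
function fd_gradient(f, z::Complex; h = 1e-6)
    dfdx = (f(z + h) - f(z - h)) / (2h)
    dfdy = (f(z + h * im) - f(z - h * im)) / (2h)
    return complex(dfdx, dfdy)
end

angle_example(x) = angle(x.re * x^2)   # reads the real part through the field `re`

fd_gradient(angle_example, 1.0 + 2.0im)   # ≈ -0.8 + 0.4im
```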
Projecting …
Currently, `Zygote.gradient` projects the cotangent returned from `pullback` to have the same sparsity structure as the input. This is mathematically incorrect when the matrix input is sparse. According to Zygote, the following function has different gradients with respect to the same inputs (mathematically speaking): …
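The following is an illustrative sketch of the kind of discrepancy described above, under the assumption that the projection keeps exactly the stored (structural) entries of a `SparseMatrixCSC`, explicit zeros included. `project_to_pattern` is a hypothetical helper, not the ChainRulesCore implementation. Two sparse matrices that compare equal but store different patterns then receive different projected gradients for f(A) = sum(A .* B), whose dense gradient with respect to A is B.

```julia
using SparseArrays

# Hypothetical helper: restrict a dense cotangent G to the stored
# (structural) entries of A, explicit zeros included.
function project_to_pattern(A::SparseMatrixCSC, G::AbstractMatrix)
    rows, cols, _ = findnz(A)
    vals = [G[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(A)...)
end

# Two representations of the same matrix: A1 stores an explicit zero at (2, 2),
# because `sparse` retains numerical zeros given in its inputs.
A1 = sparse([1, 2], [1, 2], [1.0, 0.0], 2, 2)
A2 = sparse([1], [1], [1.0], 2, 2)
A1 == A2            # true: equal as matrices
nnz(A1), nnz(A2)    # (2, 1): different stored patterns

# The dense gradient of f(A) = sum(A .* B) with respect to A is B.
B = [1.0 2.0; 3.0 4.0]
Matrix(project_to_pattern(A1, B))   # [1.0 0.0; 0.0 4.0]
Matrix(project_to_pattern(A2, B))   # [1.0 0.0; 0.0 0.0]
```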