Gradient wrt to a sparse matrix is mathematically wrong #1507

Open
mohamed82008 opened this issue Mar 24, 2024 · 6 comments · May be fixed by #1508

Comments

@mohamed82008
Contributor

Currently Zygote.gradient projects the co-tangent returned from pullback to have the same sparsity structure as the input. This is mathematically incorrect when the matrix input is sparse. According to Zygote, the following function has different gradients wrt the same inputs (mathematically speaking).

using Zygote, SparseArrays

Zygote.gradient(sum, zeros(1,1))[1] != Zygote.gradient(sum, spzeros(1,1))[1] # true
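
For reference, the projection step can be reproduced in isolation. This is a minimal sketch, assuming ChainRulesCore's `ProjectTo` is what gets applied to a sparse input as described above; the matrix `A` and the cotangent `dense_grad` are made up for illustration and are not taken from Zygote's internals.

```julia
using SparseArrays
using ChainRulesCore: ProjectTo

# Hypothetical 2x2 sparse input with a single stored entry at (1, 1).
A = sparse([1], [1], [2.0], 2, 2)

# The cotangent of `sum` with respect to a dense matrix is all ones.
dense_grad = ones(2, 2)

# Projecting onto A's sparsity pattern keeps only the stored position,
# so the other three entries of the dense cotangent are dropped.
projected = ProjectTo(A)(dense_grad)
```

Only the stored position of `A` survives, which is the behaviour this issue objects to.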

mohamed82008 linked a pull request on Mar 24, 2024 that will close this issue
@mcabbott
Member

That's quite a bold heading!

According to Zygote, the following function has different gradients wrt the same inputs (mathematically speaking).

I think you mean the inputs are == yet the gradients are different.

This is true, for example:

julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2)
(-4.810477380965351,)

julia> Zygote.gradient(x -> real(exp(x * (1+im))), pi/2 + 0im)
(-4.810477380965351 - 4.810477380965351im,)

julia> pi/2 == pi/2+0im
true

That is, Zygote (really ChainRules) regards the type of the input as specifying the domain of the function, and hence the appropriate cotangent space in which the gradient lives. The fact that x -> real(exp(x * (1+im))) uses complex numbers internally is ignored, to view this as an R -> R function when x::Real.
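
The two printed gradients can be checked by hand. The sketch below assumes the convention that the gradient with respect to z = a + b*im is reported as ∂f/∂a + im * ∂f/∂b, which is the convention that reproduces the output above; the helper names `f`, `df_da` and `df_db` are only for illustration.

```julia
# With z = a + b*im, real(exp(z * (1 + im))) equals exp(a - b) * cos(a + b).
f(a, b) = exp(a - b) * cos(a + b)

# Hand-derived partial derivatives of f with respect to a and b.
df_da(a, b) = exp(a - b) * (cos(a + b) - sin(a + b))
df_db(a, b) = -exp(a - b) * (cos(a + b) + sin(a + b))

a, b = pi / 2, 0.0

# Domain taken to be R: only the derivative along the real axis.
real_gradient = df_da(a, b)

# Domain taken to be C: both partial derivatives are reported.
complex_gradient = df_da(a, b) + im * df_db(a, b)
```

The real gradient is the real part of the complex one, which is what projecting onto the type of a real input amounts to here.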

Before ChainRules 1.0, Zygote did not do this. It regarded all numbers as living in C, and all matrices as living in C^(N×M). I think this more or less fell out of how it works and Julia's type promotion rules. The fact that this function uses complex numbers internally would lead it to tell you to do gradient descent in a complex direction.

ChainRules applies such projections to almost every step. When a real number propagates forward through some complicated code for 10 steps & then gets promoted to complex for the 11th, the reverse pass projects its gradient back to real immediately, so that these 10 reverse steps involve only real numbers again.

Similar projections apply to most structured matrix types (such as Diagonal), and also to sparse arrays.
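
A small sketch of what these projections do when called directly, using `ProjectTo` from ChainRulesCore; the input values are arbitrary and chosen only for illustration.

```julia
using LinearAlgebra
using ChainRulesCore: ProjectTo

# A real input discards the imaginary part of a complex cotangent.
real_part = ProjectTo(1.0)(2.0 + 3.0im)

# A Diagonal input keeps only the diagonal of a dense cotangent.
diag_part = ProjectTo(Diagonal([1.0, 2.0]))(ones(2, 2))
```

In both cases the projector is built from the primal input, so the cotangent ends up in the same structured type as that input.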

@mohamed82008
Contributor Author

mohamed82008 commented Mar 24, 2024

Moving some of my discussion points over from Slack and polishing them a bit.

ChainRules projects the co-tangent of a function's input to match the input's type. This can result in loss of information when a dense (or otherwise structured but not sparse, e.g. rank-1) gradient is projected onto a sparse structure. An example can be seen at https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/dense.jl#L14. In my opinion, the projection should only happen at the co-tangent of the function's output, so the rule for dot should be:

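using LinearAlgebra: dot
import ChainRulesCore: rrule, @thunk, ProjectTo, unthunk, NoTangent
import Zygote

mydot(x, y) = dot(x, y)
function rrule(::typeof(mydot), x::AbstractArray, y::AbstractArray)
    out = dot(x, y)
    # Project only the output co-tangent; the input co-tangents are returned unprojected.
    project_out = ProjectTo(out)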
    function dot_pullback(Ω̄)
        ΔΩ = project_out(unthunk(Ω̄))
        x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))
        ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))
        return (NoTangent(), x̄, ȳ)
    end
    return out, dot_pullback
end

instead of the current implementation:

function rrule(::typeof(dot), x::AbstractArray, y::AbstractArray)
    project_x = ProjectTo(x)
    project_y = ProjectTo(y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        x̄ = @thunk(project_x(reshape(y .* ΔΩ', axes(x))))
        ȳ = @thunk(project_y(reshape(x .* ΔΩ, axes(y))))
        return (NoTangent(), x̄, ȳ)
    end
    return dot(x, y), dot_pullback
end

It seems that this was done on purpose to counter another issue, which I am going to call the "too much information" problem. For example, the suggested new mydot rrule will return a complex gradient for the first argument if we differentiate dot([1,2.], [3+im, 4.]).

julia> Zygote.pullback(real ∘ mydot, [1,2], Complex.([1,2]))[2](1.0)
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])

Is this wrong? Some people think so. I don't. The user clearly is mixing complex and real numbers so it's on them to project it if they only care about the real part or to explicitly declare that x is the real component only using:

julia> f(x, y) = real(mydot(Complex.(x), y))
f (generic function with 1 method)

julia> Zygote.pullback(f, [1,2], Complex.([1,2]))[2](1.0)
([1.0, 2.0], ComplexF64[1.0 + 0.0im, 2.0 + 0.0im])

To me this is more honest to the user's intentions at the risk of returning too much information which can be discarded easily by the project function being called by the user. So to summarise, I think projecting the function's input's co-tangent leads to information loss and potentially incorrect gradients for some applications.
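
Discarding the extra information on the user's side could look like the following sketch. It assumes the user applies ChainRulesCore's `ProjectTo` (or plain `real.`) to the returned co-tangent; the variable names are illustrative only.

```julia
using ChainRulesCore: ProjectTo

x = [1.0, 2.0]                                    # the real input
cotangent = ComplexF64[1.0 + 0.0im, 2.0 + 0.0im]  # complex co-tangent as printed above

# Option 1: project onto the type of x.
via_project = ProjectTo(x)(cotangent)

# Option 2: drop the imaginary parts directly.
via_real = real.(cotangent)
```

Both options give a real vector, so the projection is a one-line opt-in rather than something every rule applies.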

@mcabbott
Member

mcabbott commented Mar 24, 2024

In a long string of rules, projecting before every rule or after every rule will be broadly similar.

But what you seem to be arguing for is omitting projection at the very first rule. Won't this lead to all kinds of surprises? For example these two functions implement the same thing in slightly different ways... why should the user care if some library changes from one implementation to the other?

julia> Zygote.pullback(x -> real(mydot((3+4im) * x, x)), [1.0, 2.0])[2](1.0)
(ComplexF64[6.0 + 4.0im, 12.0 + 8.0im],)

julia> Zygote.pullback(x -> real((3-4im) * mydot(x, x)), [1.0, 2.0])[2](1.0)
([6.0, 12.0],)

(Using the lower-level pullback here, and above, avoids the fact that gradient applies projection to the final answer. Which is there in case some rules forgot, e.g. certain non-ChainRules ones within Zygote.)
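
To see that the two functions really do implement the same thing, they can be evaluated without any AD. A minimal sketch using the standard `dot` in place of `mydot`; the names `g1` and `g2` are only for illustration. Since `dot` conjugates its first argument, both reduce to 3 * sum(abs2, x).

```julia
using LinearAlgebra: dot

# The two implementations quoted above, written with the standard `dot`.
g1(x) = real(dot((3 + 4im) * x, x))
g2(x) = real((3 - 4im) * dot(x, x))

x = [1.0, 2.0]
values_agree = g1(x) == g2(x)

# Both equal 3 * sum(abs2, x), whose gradient with respect to real x is 6x.
expected_gradient = 6 .* x
```

The real gradient [6.0, 12.0] matches the second pullback above and the real part of the first.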

Edit: in fact this example is even stranger, as the rule for (3+4im) * x applies projection. With complex input, there is no imaginary part:

julia> Zygote.gradient(x -> real(mydot((3+4im) * x, x)), [1.0, 2.0 .+ 0im])
(ComplexF64[6.0 + 0.0im, 12.0 + 0.0im],)

@mohamed82008
Contributor Author

I think the example above is an argument against projecting at all in the complex case, in that projecting complex numbers to real numbers leads to loss of information that can sometimes even violate the distributive property. So if there is a single complex number in the chain of calculations it should just propagate backward to the gradient unless users explicitly call Complex somewhere in the chain declaring that the input can only be the real component. So my new suggestion is to not project complex numbers to real numbers at all.

Concretely, the rules would now be:

mydot2(x, y) = dot(x, y)

function rrule(::typeof(mydot2), x::AbstractArray, y::AbstractArray)
   out = dot(x, y)
   function dot_pullback(Ω̄)
       ΔΩ = unthunk(Ω̄)
       x̄ = @thunk(reshape(y .* ΔΩ', axes(x)))
       ȳ = @thunk(reshape(x .* ΔΩ, axes(y)))
       return (NoTangent(), x̄, ȳ)
   end
   return out, dot_pullback
end

mymul(x, y) = x * y

function rrule(::typeof(mymul), x::Number, y::Number)
   out = mymul(x, y)
   function mul_pullback(Ω̄)
       ΔΩ = unthunk(Ω̄)
       x̄ = ΔΩ * y
       ȳ = x * ΔΩ
       return (NoTangent(), x̄, ȳ)
   end
   return out, mul_pullback
end

Zygote.pullback(x -> real(mymul((3+4im), mydot2(x, x))), [1.0, 2.0])[2](1.0)
# (ComplexF64[6.0 + 0.0im, 12.0 + 0.0im],)

Zygote.pullback(x -> real(mymul((3+4im), Complex(mydot2(x, x)))), [1.0, 2.0])[2](1.0)
# ([6.0, 12.0],)

So basically calling the constructor function of a type is the primal of the projection operation, or the projection is the pullback of the constructor. This generalises nicely to matrices as well: if I want the gradient of x -> sum(Diagonal(x)) wrt x::Vector, projection will happen naturally as the pullback of the constructor. This is distinctly different from the case where I want the gradient of sum wrt Diagonal(x), where the diagonal representation is just a compact matrix representation.
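
The claim that projection is the pullback of the constructor can be checked for the real Diagonal case with the adjoint identity <Diagonal(x), M> == <x, diag(M)>. A minimal sketch with arbitrary values; `M` stands in for some dense co-tangent of Diagonal(x).

```julia
using LinearAlgebra

x = [1.0, 2.0]
M = [1.0 2.0; 3.0 4.0]   # an arbitrary dense co-tangent for Diagonal(x)

# Inner product taken in matrix space, with Diagonal(x) written out densely.
lhs = sum(Matrix(Diagonal(x)) .* M)

# Inner product taken in vector space, after pulling M back with diag.
rhs = dot(x, diag(M))
```

The two inner products agree, so pulling a dense co-tangent back through the constructor gives diag(M), i.e. the off-diagonal entries are dropped only because the user called Diagonal.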

@mcabbott
Member

projecting complex numbers to real numbers leads to loss of information that can sometimes even violate the distributive property

Yes it loses information. I'm not sure what distributive property you mean.

But is passing in complex input in order to specify that you consider the domain to be C not R too much to ask?

It just seems extremely strange to me to want these two implementations of the same R -> R function to behave differently (as they did in earlier Zygote):

f1(x::Real) = real(exp(x * (1+im)))
f2(x::Real) = exp(x) * cos(x)
using Plots; plot(f1, -2, 2); plot!(f2, -2, 2)

It should be mentioned that the other kind of standardisation done by the projection machinery is to map structural cotangents to natural where possible. Here, by default x.re has a gradient (re = 1.0, im = nothing) but writing it as Complex(1.0, 0) makes generic rules work:

Zygote.gradient(x -> angle(x.re * x^2), 1+2im)  # (-0.8 + 0.4im,)

@mohamed82008
Contributor Author

Projecting nothing and structural no tangents is fine. Your example is not surprising to me if they return different types because one function promotes types to the complex domain and the other doesn't. I wonder if anyone actually cares about these subtle cases of returning a complex derivative instead of a real one when they can easily call real on the resulting gradient. The bigger issue is doing the opposite on purpose, returning a real gradient when the user wanted the full complex gradient or returning a diagonal matrix gradient when the user wanted the gradient wrt the full matrix.
