Problems with Zygote and the PDMat constructor #159

Open
simsurace opened this issue May 4, 2022 · 21 comments

@simsurace

simsurace commented May 4, 2022

An error is thrown when differentiating the trace of a matrix division by a PDMat:

using LinearAlgebra
using PDMats
using Zygote

function kernel(x)
    return [1. x; x 1.]
end

# PDMat is basically a wrapper for a cholesky decomposition.
# However, using `cholesky` explicitly does not throw any errors:
f(x) = tr(cholesky(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->f(only(x)), [.2]) # works

# Trying to perform the same operation through the `PDMat` constructor fails:
g(x) = tr(PDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->g(only(x)), [.2]) # ERROR
ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{PDMat{Float64, Matrix{Float64}}, Nothing, false})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/lib/lib.jl:324
  [3] (::Zygote.var"#1784#back#228"{Zygote.Jnew{PDMat{Float64, Matrix{Float64}}, Nothing, false}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:9 [inlined]
  [5] (::typeof(∂(PDMat{Float64, Matrix{Float64}})))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:16 [inlined]
  [7] (::typeof(∂(PDMat)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:19 [inlined]
  [9] (::typeof(∂(PDMat)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[7]:1 [inlined]
 [11] (::typeof(∂(f)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./REPL[8]:1 [inlined]
 [13] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#56#57"{typeof(∂(#3))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:41
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:76
 [16] top-level scope
    @ REPL[8]:1

This seems like a strange error. I tried to reproduce with my own type, but couldn't:

# The following is basically copy-pasted from master:
struct MyPDMat{T<:Real,S<:AbstractMatrix}
    dim::Int
    mat::S
    chol::Cholesky{T,S}

    MyPDMat{T,S}(d::Int,m::AbstractMatrix{T},c::Cholesky{T,S}) where {T,S} = new{T,S}(d,m,c)
end

function MyPDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    MyPDMat{T,S}(d, convert(S, mat), chol)
end

MyPDMat(mat::AbstractMatrix) = MyPDMat(mat, cholesky(mat))

Base.:\(a::MyPDMat, x::AbstractVecOrMat) = cholesky(a) \ x
LinearAlgebra.cholesky(a::MyPDMat) = a.chol

h(x) = tr(MyPDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->h(only(x)), [.2]) # works

BTW, all of these functions can be differentiated with ForwardDiff.

@devmotion
Member

This seems like a strange error.

That doesn't help, but I've seen these quite often.

I tried to reproduce with my own type, but couldn't:

A major difference is that your type is not a subtype of AbstractMatrix, and hence defaults for AbstractMatrix in Zygote and ChainRules do not affect it.
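A plain-Julia sketch of the dispatch mechanism behind this point. The names `rule_for`, `PlainWrapper` and `MatrixWrapper` are hypothetical stand-ins for real AD rules and types: a method written for `AbstractMatrix` is picked up by every subtype, while a struct outside that hierarchy falls through to the generic fallback (for Zygote, roughly, differentiating through the struct's fields).

```julia
# Hypothetical stand-ins for AD rules: one written for any AbstractMatrix,
# plus a catch-all fallback.
rule_for(::AbstractMatrix) = :abstractmatrix_rule
rule_for(::Any) = :fallback

# Not an AbstractMatrix: only the fallback applies.
struct PlainWrapper
    mat::Matrix{Float64}
end

# An AbstractMatrix subtype: the more specific AbstractMatrix method wins.
struct MatrixWrapper <: AbstractMatrix{Float64}
    mat::Matrix{Float64}
end
Base.size(w::MatrixWrapper) = size(w.mat)
Base.getindex(w::MatrixWrapper, i::Int, j::Int) = w.mat[i, j]

rule_for(PlainWrapper(zeros(2, 2)))   # :fallback
rule_for(MatrixWrapper(zeros(2, 2)))  # :abstractmatrix_rule
```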

@devmotion
Member

The usual approach for fixing these errors is defining an rrule or a projector with CR, as Will discussed in the linked PR.

@simsurace
Author

simsurace commented May 4, 2022

This is a sure sign that I don't actually understand how Zygote works. Where in the call stack would it matter whether or not MyPDMat is a subtype of AbstractMatrix? The only way the type has any bearing on the result of h is through cholesky, which has a method for MyPDMat anyway.

EDIT: In other words, I find it mysterious that Zygote can differentiate h but suddenly cannot once we add extra information by making MyPDMat a subtype of AbstractMatrix, information that seems irrelevant to this particular function call.

@simsurace
Author

Making it a subtype of AbstractMatrix indeed makes it fail:

struct MyOtherPDMat{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
    dim::Int
    mat::S
    chol::Cholesky{T,S}

    MyOtherPDMat{T,S}(d::Int,m::AbstractMatrix{T},c::Cholesky{T,S}) where {T,S} = new{T,S}(d,m,c)
end

function MyOtherPDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    MyOtherPDMat{T,S}(d, convert(S, mat), chol)
end

MyOtherPDMat(mat::AbstractMatrix) = MyOtherPDMat(mat, cholesky(mat))

Base.:\(a::MyOtherPDMat, x::AbstractVecOrMat) = cholesky(a) \ x
LinearAlgebra.cholesky(a::MyOtherPDMat) = a.chol

h(x) = tr(MyOtherPDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->h(only(x)), [.2]) # ERROR
ERROR: MethodError: no method matching size(::MyOtherPDMat{Float64, Matrix{Float64}})
Closest candidates are:
  size(::AbstractArray{T, N}, ::Any) where {T, N} at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/base/abstractarray.jl:42
  size(::Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}) at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:567
  size(::Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}, ::Integer) at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:566
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:95 [inlined]
  [2] axes(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:175
  [3] has_offset_axes(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}})
    @ Base ./abstractarray.jl:105
  [4] _tuple_any(f::typeof(Base.has_offset_axes), tf::Bool, a::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, b::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Base ./tuple.jl:516
  [5] _tuple_any(f::Function, t::Tuple{Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}})
    @ Base ./tuple.jl:513
  [6] has_offset_axes(::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, ::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Base ./abstractarray.jl:107
  [7] require_one_based_indexing(::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, ::Vararg{Any})
    @ Base ./abstractarray.jl:110
  [8] \(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, B::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:1129
  [9] (::Zygote.var"#752#753"{MyOtherPDMat{Float64, Matrix{Float64}}, Matrix{Float64}, Matrix{Float64}})(Z̄::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/lib/array.jl:494
 [10] (::Zygote.var"#3058#back#754"{Zygote.var"#752#753"{MyOtherPDMat{Float64, Matrix{Float64}}, Matrix{Float64}, Matrix{Float64}}})(Δ::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [11] Pullback
    @ ./REPL[21]:1 [inlined]
 [12] (::typeof(∂(h)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./REPL[22]:1 [inlined]
 [14] (::typeof(∂(#7)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#56#57"{typeof(∂(#7))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:41
 [16] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:76
 [17] top-level scope
    @ REPL[22]:1
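The `MethodError` for `size` arises because a subtype of `AbstractMatrix` is expected to implement the `AbstractArray` interface (at least `size` and scalar `getindex`), and the generic fallback in the stack trace computes `A' \ B` on the wrapper. Below is a minimal sketch using a hypothetical `InterfacePDMat` modelled on `MyOtherPDMat` above. It should remove this particular error, but on its own it likely just moves the type into the same situation as `PDMat`, where Zygote may then complain about the missing constructor adjoint instead.

```julia
using LinearAlgebra

# Hypothetical variant of MyOtherPDMat that also implements the minimal
# AbstractArray interface (size and scalar getindex).
struct InterfacePDMat{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
    dim::Int
    mat::S
    chol::Cholesky{T,S}
end

function InterfacePDMat(mat::AbstractMatrix, chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    return InterfacePDMat{T,S}(d, convert(S, mat), chol)
end
InterfacePDMat(mat::AbstractMatrix) = InterfacePDMat(mat, cholesky(mat))

# AbstractArray interface: generic code such as `A' \ B` relies on these.
Base.size(a::InterfacePDMat) = (a.dim, a.dim)
Base.getindex(a::InterfacePDMat, i::Int, j::Int) = a.mat[i, j]

LinearAlgebra.cholesky(a::InterfacePDMat) = a.chol
Base.:\(a::InterfacePDMat, x::AbstractVecOrMat) = cholesky(a) \ x

x = [1.0 0.2; 0.2 1.0]
a = InterfacePDMat(x)
size(a)                    # (2, 2)
a' \ [1.0 0.0; 0.0 1.0]    # no longer a MethodError for size
```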

@simsurace
Author

The usual approach for fixing these errors is defining an rrule or a projector with CR, as Will discussed in the linked PR.

Which is the function that needs an rrule? Is it

PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))

?

@oxinabox
Contributor

oxinabox commented May 4, 2022

If it doesn't happen for your own type that doesn't subtype AbstractMatrix, then there should be a solution that involves opting out of some problematic rrule.
https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/superpowers/opt_out.html

@simsurace
Author

Thanks, this sounds great! Any tips on how to find out efficiently which rrule is problematic?

@simsurace
Author

I wasn't able to figure out which rrule to opt out of. However, I came up with an rrule that allows me to differentiate through the rrule interface directly.

I started with det(PDMat(x)) because it throws the same error as the more complicated example involving matrix division.

However, Zygote does not seem to recognize my rrule and still complains about a missing adjoint:

using LinearAlgebra
using PDMats
using Zygote

x = [1. 0.2; 0.2 1.]
y = [1. 0.1; 0.1 1.]

Zygote.gradient(logdet ∘ PDMat, x) |> only # works
Zygote.gradient(det ∘ PDMat, x) |> only # ERROR
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

# Since `logdet` has an overload for `PDMat` and `det` doesn't, 
# and `logdet` above works while `det` doesn't, try to add an overload
# for `det` and see what happens:
LinearAlgebra.det(A::PDMat) = det(A.chol)
Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

# While the overload may be useful for efficiency, it does not solve the AD issue
# Next, we will try to define an rrule
using ChainRules, ChainRulesCore

# rrule draft for constructor of PDMat
# This is probably not completely correct, but should be a good start
function ChainRulesCore.rrule(::Type{PDMat}, mat)
    chol, chol_pullback = rrule(cholesky, mat)
    y = PDMat(mat, chol)
    function PDMat_pullbackCR(m̄at::AbstractMatrix)
        @info "Using CR for PDMat, AbstractMatrix tangent"
        return NoTangent(), m̄at
    end
    function PDMat_pullbackCR(m̄at::Tangent)
        @info "Using CR for PDMat, Tangent type"
        # chol_pullback returns a tuple; element [2] is the tangent for `mat`
        return NoTangent(), chol_pullback(m̄at.chol)[2]
    end
    return y, PDMat_pullbackCR
end

# Perform the individual forward and backward steps manually:
a, a_pullback = rrule(PDMat, x)
b, b_pullback = rrule(det, a)

b̄ = 1.
_, ā = b_pullback(b̄)
_, x̄ = a_pullback(ā)

# Compare to the result without the `PDMat` wrapper
unthunk(x̄) ≈ Zygote.gradient(det, x) |> only # true
Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

@oxinabox
Contributor

oxinabox commented May 6, 2022

You often need to call Zygote.refresh() to get it to pick up new rrules after it has been used once.

@theogf
Contributor

theogf commented May 7, 2022

Just tried this code, and Zygote.gradient(det ∘ PDMat, x) works but then Zygote.gradient(logdet ∘ PDMat, x) fails...

@theogf
Contributor

theogf commented May 7, 2022

Interestingly, ForwardDiff also returns the wrong result if x is not already a PDMat:

julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det, PDMat(x))
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0

julia> ForwardDiff.gradient(det, x)
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0

@devmotion
Member

devmotion commented May 7, 2022

I started with det(PDMat(x)) because it throws the same error as the more complicated example involving matrix division.

Probably one should restrict rrule(::typeof(det), ...) in the same way as the one for logdet: JuliaDiff/ChainRules.jl#245

Regardless of AD, I think it would be useful to add definitions of det(::AbstractPDMat) = .... since otherwise these calls will fall back to the generic definitions based on the LU decomposition in LinearAlgebra: https://github.com/JuliaLang/julia/blob/6e061322438f13c6548200f115f3c31b20860a30/stdlib/LinearAlgebra/src/generic.jl#L1598-L1604
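For an SPD matrix A = U'U the determinant follows directly from the diagonal of the Cholesky factor, det(A) = prod(diag(U))^2, so no LU decomposition is needed. A minimal sketch with hypothetical helper names (not necessarily the definitions that ended up in PDMats); LinearAlgebra's own `det(::Cholesky)` and `logdet(::Cholesky)` compute the same quantities:

```julia
using LinearAlgebra

# Hypothetical helpers: determinant and log-determinant of an SPD matrix
# from its Cholesky factorization A = U'U.
det_from_chol(C::Cholesky) = prod(diag(C.U))^2
logdet_from_chol(C::Cholesky) = 2 * sum(log, diag(C.U))

x = [1.0 0.2; 0.2 1.0]
C = cholesky(x)

det_from_chol(C)      # ≈ 0.96, same as det(x) and det(C)
logdet_from_chol(C)   # ≈ log(0.96)
```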

Probably logdet(::PDMat) does not error because Zygote defines an adjoint for logdet(::Cholesky) (which, I guess, should be moved to ChainRules together with a rule for det(::Cholesky)): https://github.com/FluxML/Zygote.jl/blob/a392eabdc0217f2f34d77ce19d4167c3cd4abbcf/src/lib/array.jl#L744-L748 However, as in the ForwardDiff example, the derivative seems to be wrong, since the gradient it returns is a triangular rather than a Hermitian matrix.

I think the right approach would be to add a projection mechanism for PDMat and rrules for the constructor similar to the one for Hermitian and Symmetric matrices (https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/projection.jl#L425-L451 and https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/symmetric.jl#L5-L92).
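The core idea of such a projection can be sketched without any AD package. Map a matrix cotangent onto the symmetric matrices by averaging it with its transpose. This is the orthogonal projection under the Frobenius inner product, so it leaves every directional derivative along a symmetric perturbation unchanged. This illustrates the math only and is not ChainRulesCore's `ProjectTo` implementation; `project_symmetric` is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical helper: orthogonal projection (w.r.t. the Frobenius inner
# product) of a matrix cotangent onto the space of symmetric matrices.
project_symmetric(G::AbstractMatrix) = Symmetric((G + G') / 2)

# Triangular gradient, as returned by ForwardDiff.gradient(det ∘ PDMat, x) above
G_tri = [1.0 -0.4; 0.0 1.0]
G_sym = project_symmetric(G_tri)   # [1.0 -0.2; -0.2 1.0]

# For a symmetric direction dX both give the same directional derivative.
dX = [0.3 0.5; 0.5 0.2]
dot(G_tri, dX) ≈ dot(G_sym, dX)    # true
```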

@simsurace
Author

Interestingly ForwardDiff also returns the wrong thing if x is not a PDMat already:

I think this would be a great test to do for all the functions overloaded for PDMat. Basically, the wrapper should not change any derivatives wrt. x.

@devmotion
Member

Regardless of AD, I think it would be useful to add definitions of det(::AbstractPDMat) = .... since otherwise these calls will fall back to the generic definitions based on the LU decomposition in LinearAlgebra: https://github.com/JuliaLang/julia/blob/6e061322438f13c6548200f115f3c31b20860a30/stdlib/LinearAlgebra/src/generic.jl#L1598-L1604

I opened #161.

@devmotion
Member

devmotion commented May 9, 2022

The ForwardDiff issue is not related to PDMats:

julia> using PDMats, ForwardDiff, LinearAlgebra

julia> x = [1. 0.2; 0.2 1.];

julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det ∘ cholesky, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

From the perspective of det etc., PDMat does not wrap x but cholesky(x), so comparing with ForwardDiff.gradient(det, x) is not appropriate. Thus the seemingly incorrect derivative is caused by cholesky and is actually expected.

@simsurace
Author

So cholesky is not defined for non-symmetric x, which means that a gradient step will leave the subspace of symmetric matrices. This is an instance of the gradient being correct as a differential, i.e. for any choice of symmetric tangent vector the differential applied to that vector gives the correct result, but the gradient is not itself a (symmetric) tangent vector. Should I open an issue in DiffRules.jl?

@simsurace
Author

Especially because using a Symmetric wrapper fails:

julia> ForwardDiff.gradient(logdet ∘ PDMat, Symmetric(x))
ERROR: ArgumentError: Cannot set a non-diagonal index in a symmetric matrix

whereas it works for Zygote:

julia> Zygote.gradient(logdet ∘ PDMat, Symmetric(x)) |> only
2×2 Symmetric{Float64, Matrix{Float64}}:
  1.04167   -0.208333
 -0.208333   1.04167
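For reference, the analytic gradient of logdet at an invertible X is inv(X)'. For this symmetric x that is just inv(x), which matches the Zygote output above. A quick stdlib check, including a central finite difference along a symmetric direction (the values of E and h are arbitrary choices):

```julia
using LinearAlgebra

x = [1.0 0.2; 0.2 1.0]

# Analytic gradient of logdet: inv(X)', which equals inv(x) for symmetric x.
G = inv(x)   # ≈ [1.04167 -0.208333; -0.208333 1.04167]

# Central finite difference of logdet along a symmetric direction E,
# compared with the directional derivative tr(G' * E) = dot(G, E).
E = [0.1 0.3; 0.3 -0.2]
h = 1e-6
fd = (logdet(x + h * E) - logdet(x - h * E)) / (2h)
fd ≈ dot(G, E)   # true
```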

@devmotion
Member

Actually, I don't think there's anything wrong with the derivatives of ForwardDiff and cholesky; the forward-mode sensitivities of cholesky are correct (compare e.g. with https://arxiv.org/abs/1602.07527). It's just that ForwardDiff.gradient(det ∘ cholesky, x) is fundamentally different from ForwardDiff.gradient(det, x): the first one assumes that x AND the sensitivities/perturbations/... dx are symmetric matrices, whereas the second one does not use any of these assumptions.
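A small stdlib-only illustration of this point. It assumes, as the `ForwardDiff.gradient(det ∘ cholesky, x)` output above suggests, that `cholesky` effectively reads only the upper triangle of a symmetric input. `det_chol` and `fd_gradient` are hypothetical helpers, and central finite differences stand in for ForwardDiff:

```julia
using LinearAlgebra

x = [1.0 0.2; 0.2 1.0]

# det through a Cholesky factorization of the upper triangle of X
# (the lower triangle is ignored, mimicking how `cholesky` reads its input).
det_chol(X) = det(cholesky(Symmetric(X, :U)))

# Central finite-difference gradient, perturbing every entry independently.
function fd_gradient(f, X::Matrix{Float64}; h = 1e-6)
    G = zeros(size(X))
    for i in eachindex(X)
        E = zeros(size(X))
        E[i] = h
        G[i] = (f(X + E) - f(X - E)) / (2h)
    end
    return G
end

G_full = fd_gradient(det, x)       # ≈ [1.0 -0.2; -0.2 1.0]
G_chol = fd_gradient(det_chol, x)  # ≈ [1.0 -0.4; 0.0 1.0]

# Directional derivatives agree along a symmetric perturbation ...
dX_sym = [0.3 0.5; 0.5 0.2]
dot(G_full, dX_sym), dot(G_chol, dX_sym)    # both ≈ 0.3
# ... but not along a non-symmetric one.
dX_asym = [0.0 1.0; 0.0 0.0]
dot(G_full, dX_asym), dot(G_chol, dX_asym)  # ≈ (-0.2, -0.4)
```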

@devmotion
Member

Especially because using a Symmetric wrapper fails:

This is a general issue with ForwardDiff.seed!, not limited to Symmetric and not related to cholesky.

@simsurace
Author

Yes, as I said I do believe that the differential is correct because it is only defined for a symmetric tangent. I think we can focus on getting the gradients working in Zygote.

@devmotion
Member

The only things to make it work with Zygote are (copied from above):

Probably one should restrict rrule(::typeof(det), ...) in the same way as the one for logdet: JuliaDiff/ChainRules.jl#245
Zygote defines an adjoint for logdet(::Cholesky) (should be moved to ChainRules I guess together with a rule for det(::Cholesky)): https://github.com/FluxML/Zygote.jl/blob/a392eabdc0217f2f34d77ce19d4167c3cd4abbcf/src/lib/array.jl#L744-L748

I just checked locally: with these changes, det works too. I'll open a PR in ChainRules.

4 participants