From e24dfe4b5fa64018d582e2522b987df4c0dd7751 Mon Sep 17 00:00:00 2001 From: Jutho Date: Sat, 9 Nov 2024 09:49:18 +0100 Subject: [PATCH] fix arnoldi ad rule for degenerate eigsolve (#99) * fix arnoldi ad rule for degenerate eigsolve * another fix attempt * the problem is with finite difference * cleanup --- ext/KrylovKitChainRulesCoreExt/eigsolve.jl | 35 +- test/ad.jl | 890 --------------------- test/ad/degenerateeigsolve.jl | 169 ++++ test/ad/eigsolve.jl | 391 +++++++++ test/ad/linsolve.jl | 129 +++ test/ad/svdsolve.jl | 368 +++++++++ test/runtests.jl | 5 +- 7 files changed, 1084 insertions(+), 903 deletions(-) delete mode 100644 test/ad.jl create mode 100644 test/ad/degenerateeigsolve.jl create mode 100644 test/ad/eigsolve.jl create mode 100644 test/ad/linsolve.jl create mode 100644 test/ad/svdsolve.jl diff --git a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl index bb14a0b..ab1dcab 100644 --- a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl +++ b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl @@ -243,11 +243,15 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, # [(A * (1-P) + shift * P) -ΔV; 0 Λ], where eᵢ is the ith unit vector. We will need # to renormalise the eigenvectors to have exactly eᵢ as second component. We use # (0, e₁ + e₂ + ... + eₙ) as the initial guess for the eigenvalue problem. + W₀ = (zerovector(vecs[1]), one.(vals)) P = orthogonalprojector(vecs, n, Gc) + # TODO: is `realeigsolve` ever used here, as there is a separate `alg_primal::Lanczos` method below? solver = (T <: Real) ? 
KrylovKit.realeigsolve : KrylovKit.eigsolve # for `eigsolve`, `T` will always be a Complex subtype` - rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift - solver(W₀, n, reverse_which(which), alg_rrule) do (w, x) + rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift, + eigsort = EigSorter(v -> minimum(DistanceTo(conj(v)), vals)) + + solver(W₀, n, eigsort, alg_rrule) do (w, x) w₀ = P(w) w′ = KrylovKit.apply(fᴴ, add(w, w₀, -1)) if !iszero(shift) @@ -268,20 +272,25 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, tol = alg_rrule.tol Q = orthogonalcomplementprojector(vecs, n, Gc) for i in 1:n - w, x = Ws[i] - _, ic = findmax(abs, x) - factor = 1 / x[ic] - x[ic] = zero(x[ic]) + d, ic = findmin(DistanceTo(conj(vals[i])), rvals) + w, x = Ws[ic] + factor = 1 / x[i] + x[i] = zero(x[i]) if alg_rrule.verbosity >= 0 - error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic]))) - error > 5 * tol && - @warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error" + error = max(norm(x, Inf), abs(rvals[ic] - conj(vals[i]))) + error > 10 * tol && + @warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $error" end - ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor) + ws[i] = VectorInterface.add!!(zs[i], Q(w), -factor) end return ws end +struct DistanceTo{T} + x::T +end +(d::DistanceTo)(y) = norm(y - d.x) + # several simplications happen in the case of a Hermitian eigenvalue problem function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, alg_primal::Lanczos, alg_rrule::Arnoldi) @@ -342,8 +351,10 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, W₀ = (zerovector(vecs[1]), one.(vals)) P = orthogonalprojector(vecs, n) solver = (T <: Real) ? 
KrylovKit.realeigsolve : KrylovKit.eigsolve - rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift - solver(W₀, n, reverse_which(which), alg_rrule) do (w, x) + rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift, + eigsort = EigSorter(v -> minimum(DistanceTo(conj(v)), vals)) + + solver(W₀, n, eigsort, alg_rrule) do (w, x) w₀ = P(w) w′ = KrylovKit.apply(fᴴ, add(w, w₀, -1)) if !iszero(shift) diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 20c82ba..0000000 --- a/test/ad.jl +++ /dev/null @@ -1,890 +0,0 @@ -module LinsolveAD -using KrylovKit, LinearAlgebra -using Random, Test, TestExtras -using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences - -fdm = ChainRulesTestUtils._fdm -n = 10 -N = 30 - -function build_mat_example(A, b, x, alg, alg_rrule) - Avec, A_fromvec = to_vec(A) - bvec, b_fromvec = to_vec(b) - xvec, x_fromvec = to_vec(x) - T = eltype(A) - - function mat_example(Av, bv, xv) - à = A_fromvec(Av) - b̃ = b_fromvec(bv) - x̃ = x_fromvec(xv) - x, info = linsolve(Ã, b̃, x̃, alg; alg_rrule=alg_rrule) - if info.converged == 0 - @warn "linsolve did not converge:" - println("normres = ", info.normres) - end - xv, = to_vec(x) - return xv - end - function mat_example_fun(Av, bv, xv) - à = A_fromvec(Av) - b̃ = b_fromvec(bv) - x̃ = x_fromvec(xv) - f = x -> à * x - x, info = linsolve(f, b̃, x̃, alg; alg_rrule=alg_rrule) - if info.converged == 0 - @warn "linsolve did not converge:" - println("normres = ", info.normres) - end - xv, = to_vec(x) - return xv - end - return mat_example, mat_example_fun, Avec, bvec, xvec -end - -function build_fun_example(A, b, c, d, e, f, alg, alg_rrule) - Avec, matfromvec = to_vec(A) - bvec, vecfromvec = to_vec(b) - cvec, = to_vec(c) - dvec, = to_vec(d) - evec, scalarfromvec = to_vec(e) - fvec, = to_vec(f) - - function fun_example(Av, bv, cv, dv, ev, fv) - à = matfromvec(Av) - b̃ = vecfromvec(bv) - c̃ = vecfromvec(cv) - d̃ = vecfromvec(dv) - ẽ = scalarfromvec(ev) - f̃ = 
scalarfromvec(fv) - - x, info = linsolve(b̃, zero(b̃), alg, ẽ, f̃; alg_rrule=alg_rrule) do y - return à * y + c̃ * dot(d̃, y) - end - # info.converged > 0 || @warn "not converged" - xv, = to_vec(x) - return xv - end - return fun_example, Avec, bvec, cvec, dvec, evec, fvec -end - -@testset "Small linsolve AD test with eltype=$T" for T in (Float32, Float64, ComplexF32, - ComplexF64) - A = 2 * (rand(T, (n, n)) .- one(T) / 2) - b = 2 * (rand(T, n) .- one(T) / 2) - b /= norm(b) - x = 2 * (rand(T, n) .- one(T) / 2) - - condA = cond(A) - tol = condA * (T <: Real ? eps(T) : 4 * eps(real(T))) - alg = GMRES(; tol=tol, krylovdim=n, maxiter=1) - - config = Zygote.ZygoteRuleConfig() - _, pb = ChainRulesCore.rrule(config, linsolve, A, b, x, alg, 0, 1; alg_rrule=alg) - @constinferred pb((ZeroTangent(), NoTangent())) - @constinferred pb((rand(T, n), NoTangent())) - - mat_example, mat_example_fun, Avec, bvec, xvec = build_mat_example(A, b, x, alg, alg) - (JA, Jb, Jx) = FiniteDifferences.jacobian(fdm, mat_example, Avec, bvec, xvec) - (JA1, Jb1, Jx1) = Zygote.jacobian(mat_example, Avec, bvec, xvec) - (JA2, Jb2, Jx2) = Zygote.jacobian(mat_example_fun, Avec, bvec, xvec) - - @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) - # factor 2 is minimally necessary for complex case, but 3 is more robust - @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) - @test all(iszero, Jx1) -end - -@testset "Large linsolve AD test with eltype=$T" for T in (Float64, ComplexF64) - A = rand(T, (N, N)) .- one(T) / 2 - A = I - (9 // 10) * A / maximum(abs, eigvals(A)) - b = 2 * (rand(T, N) .- one(T) / 2) - c = 2 * (rand(T, N) .- one(T) / 2) - d = 2 * (rand(T, N) .- one(T) / 2) - e = rand(T) - f = rand(T) - - # mix algorithms] - tol = N^2 * eps(real(T)) - alg1 = GMRES(; tol=tol, krylovdim=20) - alg2 = BiCGStab(; tol=tol, maxiter=100) # BiCGStab seems to require slightly smaller tolerance for tests to work - for (alg, alg_rrule) in ((alg1, alg2), 
(alg2, alg1)) - fun_example, Avec, bvec, cvec, dvec, evec, fvec = build_fun_example(A, b, c, d, e, - f, alg, - alg_rrule) - - (JA, Jb, Jc, Jd, Je, Jf) = FiniteDifferences.jacobian(fdm, fun_example, - Avec, bvec, cvec, dvec, evec, - fvec) - (JA′, Jb′, Jc′, Jd′, Je′, Jf′) = Zygote.jacobian(fun_example, Avec, bvec, cvec, - dvec, evec, fvec) - @test JA ≈ JA′ - @test Jb ≈ Jb′ - @test Jc ≈ Jc′ - @test Jd ≈ Jd′ - @test Je ≈ Je′ - @test Jf ≈ Jf′ - end -end -end - -module EigsolveAD -using KrylovKit, LinearAlgebra -using Random, Test, TestExtras -using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences -Random.seed!(987654321) - -fdm = ChainRulesTestUtils._fdm -n = 10 -N = 30 - -function build_mat_example(A, x, howmany::Int, which, alg, alg_rrule) - Avec, A_fromvec = to_vec(A) - xvec, x_fromvec = to_vec(x) - - vals, vecs, info = eigsolve(A, x, howmany, which, alg) - info.converged < howmany && @warn "eigsolve did not converge" - if eltype(A) <: Real && length(vals) > howmany && - vals[howmany] == conj(vals[howmany + 1]) - howmany += 1 - end - - function mat_example(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - vals′, vecs′, info′ = eigsolve(Ã, x̃, howmany, which, alg; alg_rrule=alg_rrule) - info′.converged < howmany && @warn "eigsolve did not converge" - catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - function mat_example_fun(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - f = x -> à * x - vals′, vecs′, info′ = eigsolve(f, x̃, howmany, which, alg; alg_rrule=alg_rrule) - info′.converged < howmany && @warn "eigsolve did not converge" - catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
- if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - function mat_example_fd(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - vals′, vecs′, info′ = eigsolve(Ã, x̃, howmany, which, alg; alg_rrule=alg_rrule) - info′.converged < howmany && @warn "eigsolve did not converge" - for i in 1:howmany - d = dot(vecs[i], vecs′[i]) - @assert abs(d) > sqrt(eps(real(eltype(A)))) - phasefix = abs(d) / d - vecs′[i] = vecs′[i] * phasefix - end - catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - return mat_example, mat_example_fun, mat_example_fd, Avec, xvec, vals, vecs, howmany -end - -function build_fun_example(A, x, c, d, howmany::Int, which, alg, alg_rrule) - Avec, matfromvec = to_vec(A) - xvec, vecfromvec = to_vec(x) - cvec, = to_vec(c) - dvec, = to_vec(d) - - vals, vecs, info = eigsolve(x, howmany, which, alg) do y - return A * y + c * dot(d, y) - end - info.converged < howmany && @warn "eigsolve did not converge" - if eltype(A) <: Real && length(vals) > howmany && - vals[howmany] == conj(vals[howmany + 1]) - howmany += 1 - end - - fun_example_ad = let howmany′ = howmany - function (Av, xv, cv, dv) - à = matfromvec(Av) - x̃ = vecfromvec(xv) - c̃ = vecfromvec(cv) - d̃ = vecfromvec(dv) - - vals′, vecs′, info′ = eigsolve(x̃, howmany′, which, alg; - alg_rrule=alg_rrule) do y - return à * y + c̃ * dot(d̃, y) - end - info′.converged < howmany′ && @warn "eigsolve did not converge" - catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) 
- if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - end - - fun_example_fd = let howmany′ = howmany - function (Av, xv, cv, dv) - à = matfromvec(Av) - x̃ = vecfromvec(xv) - c̃ = vecfromvec(cv) - d̃ = vecfromvec(dv) - - vals′, vecs′, info′ = eigsolve(x̃, howmany′, which, alg; - alg_rrule=alg_rrule) do y - return à * y + c̃ * dot(d̃, y) - end - info′.converged < howmany′ && @warn "eigsolve did not converge" - for i in 1:howmany′ - d = dot(vecs[i], vecs′[i]) - @assert abs(d) > sqrt(eps(real(eltype(A)))) - phasefix = abs(d) / d - vecs′[i] = vecs′[i] * phasefix - end - catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - end - - return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany -end - -function build_hermitianfun_example(A, x, c, howmany::Int, which, alg, alg_rrule) - Avec, matfromvec = to_vec(A) - xvec, xvecfromvec = to_vec(x) - cvec, cvecfromvec = to_vec(c) - - vals, vecs, info = eigsolve(x, howmany, which, alg) do y - return Hermitian(A) * y + c * dot(c, y) - end - info.converged < howmany && @warn "eigsolve did not converge" - - function fun_example(Av, xv, cv) - à = matfromvec(Av) - x̃ = xvecfromvec(xv) - c̃ = cvecfromvec(cv) - - vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg; - alg_rrule=alg_rrule) do y - return Hermitian(Ã) * y + c̃ * dot(c̃, y) - end - info′.converged < howmany && @warn "eigsolve did not converge" - catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
- if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - function fun_example_fd(Av, xv, cv) - à = matfromvec(Av) - x̃ = xvecfromvec(xv) - c̃ = cvecfromvec(cv) - - vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg; - alg_rrule=alg_rrule) do y - return Hermitian(Ã) * y + c̃ * dot(c̃, y) - end - info′.converged < howmany && @warn "eigsolve did not converge" - for i in 1:howmany - d = dot(vecs[i], vecs′[i]) - @assert abs(d) > sqrt(eps(real(eltype(A)))) - phasefix = abs(d) / d - vecs′[i] = vecs′[i] * phasefix - end - catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - return fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany -end - -@timedtestset "Small eigsolve AD test for eltype=$T" for T in - (Float32, Float64, ComplexF32, - ComplexF64) - if T <: Complex - whichlist = (:LM, :SR, :LR, :SI, :LI) - else - whichlist = (:LM, :SR, :LR) - end - A = 2 * (rand(T, (n, n)) .- one(T) / 2) - x = 2 * (rand(T, n) .- one(T) / 2) - x /= norm(x) - - howmany = 3 - condA = cond(A) - tol = n * condA * (T <: Real ? 
eps(T) : 4 * eps(real(T))) - alg = Arnoldi(; tol=tol, krylovdim=n) - alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) - alg_rrule2 = GMRES(; tol=tol, krylovdim=n + 1, verbosity=-1) - config = Zygote.ZygoteRuleConfig() - @testset for which in whichlist - for alg_rrule in (alg_rrule1, alg_rrule2) - # unfortunately, rrule does not seem type stable for function arguments, because the - # `rrule_via_ad` call does not produce type stable `rrule`s for the function - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - which, alg; alg_rrule=alg_rrule) - # NOTE: the following is not necessary here, as it is corrected for in the `eigsolve` rrule - # if length(vals) > howmany && vals[howmany] == conj(vals[howmany + 1]) - # howmany += 1 - # end - @constinferred pb((ZeroTangent(), ZeroTangent(), NoTangent())) - @constinferred pb((randn(T, howmany), ZeroTangent(), NoTangent())) - @constinferred pb((randn(T, howmany), [randn(T, n)], NoTangent())) - @constinferred pb((randn(T, howmany), [randn(T, n) for _ in 1:howmany], - NoTangent())) - end - - for alg_rrule in (alg_rrule1, alg_rrule2) - mat_example, mat_example_fun, mat_example_fd, Avec, xvec, vals, vecs, howmany = build_mat_example(A, - x, - howmany, - which, - alg, - alg_rrule) - - (JA, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, xvec) - (JA1, Jx1) = Zygote.jacobian(mat_example, Avec, xvec) - (JA2, Jx2) = Zygote.jacobian(mat_example_fun, Avec, xvec) - - # finite difference comparison using some kind of tolerance heuristic - @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) - @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) - @test all(iszero, Jx1) - @test all(iszero, Jx2) - - # some analysis - ∂vals = complex.(JA1[1:howmany, :], JA1[howmany * (n + 1) .+ (1:howmany), :]) - ∂vecs = map(1:howmany) do i - return complex.(JA1[(howmany + (i - 1) * n) .+ (1:n), :], - JA1[(howmany * (n + 2) + (i - 1) * n) .+ (1:n), :]) 
- end - if eltype(A) <: Complex # test holomorphicity / Cauchy-Riemann equations - # for eigenvalues - @test real(∂vals[:, 1:2:(2n^2)]) ≈ +imag(∂vals[:, 2:2:(2n^2)]) - @test imag(∂vals[:, 1:2:(2n^2)]) ≈ -real(∂vals[:, 2:2:(2n^2)]) - # and for eigenvectors - for i in 1:howmany - @test real(∂vecs[i][:, 1:2:(2n^2)]) ≈ +imag(∂vecs[i][:, 2:2:(2n^2)]) - @test imag(∂vecs[i][:, 1:2:(2n^2)]) ≈ -real(∂vecs[i][:, 2:2:(2n^2)]) - end - end - # test orthogonality of vecs and ∂vecs - for i in 1:howmany - @test all(isapprox.(abs.(vecs[i]' * ∂vecs[i]), 0; atol=sqrt(eps(real(T))))) - end - end - end - - if T <: Complex - @testset "test warnings and info" begin - alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=-1) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent())) - - alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=0) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs (:warn,) pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], - NoTangent())) - pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent())) - @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) - - alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=1) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs (:warn,) (:info,) pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], - NoTangent())) - pbs = @test_logs (:info,) pb((ZeroTangent(), vecs[1:2], NoTangent())) - @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) - - alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=-1) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent())) - - alg_rrule = GMRES(; tol=tol, 
krylovdim=n, verbosity=0) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs (:warn,) (:warn,) pb((ZeroTangent(), - im .* vecs[1:2] .+ - vecs[2:-1:1], - NoTangent())) - pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent())) - @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) - - alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=1) - (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, - :LR, alg; alg_rrule=alg_rrule) - @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), - im .* vecs[1:2] .+ - vecs[2:-1:1], - NoTangent())) - pbs = @test_logs (:info,) (:info,) pb((ZeroTangent(), vecs[1:2], NoTangent())) - @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) - end - end -end -@timedtestset "Large eigsolve AD test with eltype=$T" for T in (Float64, ComplexF64) - if T <: Complex - whichlist = (:LM, :SI) - else - whichlist = (:LM, :SR) - end - @testset for which in whichlist - A = rand(T, (N, N)) .- one(T) / 2 - A = I - (9 // 10) * A / maximum(abs, eigvals(A)) - x = 2 * (rand(T, N) .- one(T) / 2) - x /= norm(x) - c = 2 * (rand(T, N) .- one(T) / 2) - d = 2 * (rand(T, N) .- one(T) / 2) - - howmany = 2 - tol = 2 * N^2 * eps(real(T)) - alg = Arnoldi(; tol=tol, krylovdim=2n) - alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) - alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) - @testset for alg_rrule in (alg_rrule1, alg_rrule2) - fun_example, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A, - x, - c, - d, - howmany, - which, - alg, - alg_rrule) - - (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, - cvec, dvec) - (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example, Avec, xvec, cvec, dvec) - @test JA ≈ JA′ - @test Jc ≈ Jc′ - @test Jd ≈ Jd′ - end - end -end -@timedtestset "Large Hermitian eigsolve AD test with eltype=$T" for T in - (Float64, ComplexF64) - 
whichlist = (:LR, :SR) - @testset for which in whichlist - A = rand(T, (N, N)) .- one(T) / 2 - A = I - (9 // 10) * A / maximum(abs, eigvals(A)) - x = 2 * (rand(T, N) .- one(T) / 2) - x /= norm(x) - c = 2 * (rand(T, N) .- one(T) / 2) - - howmany = 2 - tol = 2 * N^2 * eps(real(T)) - alg = Lanczos(; tol=tol, krylovdim=2n) - alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) - alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) - @testset for alg_rrule in (alg_rrule1, alg_rrule2) - fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany = build_hermitianfun_example(A, - x, - c, - howmany, - which, - alg, - alg_rrule) - - (JA, Jx, Jc) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, - cvec) - (JA′, Jx′, Jc′) = Zygote.jacobian(fun_example, Avec, xvec, cvec) - @test JA ≈ JA′ - @test Jc ≈ Jc′ - end - end -end - -end - -module SvdsolveAD -using KrylovKit, LinearAlgebra -using Random, Test, TestExtras -using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences -Random.seed!(123456789) - -fdm = ChainRulesTestUtils._fdm -n = 10 -N = 30 - -function build_mat_example(A, x, howmany::Int, alg, alg_rrule) - Avec, A_fromvec = to_vec(A) - xvec, x_fromvec = to_vec(x) - - vals, lvecs, rvecs, info = svdsolve(A, x, howmany, :LR, alg) - info.converged < howmany && @warn "svdsolve did not converge" - - function mat_example_mat(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - vals′, lvecs′, rvecs′, info′ = svdsolve(Ã, x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - function mat_example_fval(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - f = (x, adj::Val) -> (adj isa Val{true}) ? 
adjoint(Ã) * x : à * x - vals′, lvecs′, rvecs′, info′ = svdsolve(f, x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - function mat_example_ftuple(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - (f, fᴴ) = (x -> à * x, x -> adjoint(Ã) * x) - vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - function mat_example_fd(Av, xv) - à = A_fromvec(Av) - x̃ = x_fromvec(xv) - vals′, lvecs′, rvecs′, info′ = svdsolve(Ã, x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - for i in 1:howmany - dl = dot(lvecs[i], lvecs′[i]) - dr = dot(rvecs[i], rvecs′[i]) - @assert abs(dl) > sqrt(eps(real(eltype(A)))) - @assert abs(dr) > sqrt(eps(real(eltype(A)))) - phasefix = sqrt(abs(dl * dr) / (dl * dr)) - lvecs′[i] = lvecs′[i] * phasefix - rvecs′[i] = rvecs′[i] * phasefix - end - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) 
- if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - - return mat_example_mat, mat_example_ftuple, mat_example_fval, mat_example_fd, Avec, - xvec, vals, lvecs, rvecs -end - -function build_fun_example(A, x, c, d, howmany::Int, alg, alg_rrule) - Avec, matfromvec = to_vec(A) - xvec, xvecfromvec = to_vec(x) - cvec, cvecfromvec = to_vec(c) - dvec, dvecfromvec = to_vec(d) - - f = y -> A * y + c * dot(d, y) - fᴴ = y -> adjoint(A) * y + d * dot(c, y) - vals, lvecs, rvecs, info = svdsolve((f, fᴴ), x, howmany, :LR, alg) - info.converged < howmany && @warn "svdsolve did not converge" - - function fun_example_ad(Av, xv, cv, dv) - à = matfromvec(Av) - x̃ = xvecfromvec(xv) - c̃ = cvecfromvec(cv) - d̃ = dvecfromvec(dv) - - f = y -> à * y + c̃ * dot(d̃, y) - fᴴ = y -> adjoint(Ã) * y + d̃ * dot(c̃, y) - vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) - if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - function fun_example_fd(Av, xv, cv, dv) - à = matfromvec(Av) - x̃ = xvecfromvec(xv) - c̃ = cvecfromvec(cv) - d̃ = dvecfromvec(dv) - - f = y -> à * y + c̃ * dot(d̃, y) - fᴴ = y -> adjoint(Ã) * y + d̃ * dot(c̃, y) - vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; - alg_rrule=alg_rrule) - info′.converged < howmany && @warn "svdsolve did not converge" - for i in 1:howmany - dl = dot(lvecs[i], lvecs′[i]) - dr = dot(rvecs[i], rvecs′[i]) - @assert abs(dl) > sqrt(eps(real(eltype(A)))) - @assert abs(dr) > sqrt(eps(real(eltype(A)))) - phasefix = sqrt(abs(dl * dr) / (dl * dr)) - lvecs′[i] = lvecs′[i] * phasefix - rvecs′[i] = rvecs′[i] * phasefix - end - catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) 
- if eltype(catresults) <: Complex - return vcat(real(catresults), imag(catresults)) - else - return catresults - end - end - return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs -end - -@timedtestset "Small svdsolve AD test with eltype=$T" for T in - (Float32, Float64, ComplexF32, - ComplexF64) - A = 2 * (rand(T, (n, 2 * n)) .- one(T) / 2) - x = 2 * (rand(T, n) .- one(T) / 2) - x /= norm(x) - condA = cond(A) - - howmany = 3 - tol = 3 * n * condA * (T <: Real ? eps(T) : 4 * eps(real(T))) - alg = GKL(; krylovdim=2n, tol=tol) - alg_rrule1 = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1) - alg_rrule2 = GMRES(; tol=tol, krylovdim=3n, verbosity=-1) - config = Zygote.ZygoteRuleConfig() - for alg_rrule in (alg_rrule1, alg_rrule2) - # unfortunately, rrule does not seem type stable for function arguments, because the - # `rrule_via_ad` call does not produce type stable `rrule`s for the function - _, pb = ChainRulesCore.rrule(config, svdsolve, A, x, howmany, :LR, alg; - alg_rrule=alg_rrule) - @constinferred pb((ZeroTangent(), ZeroTangent(), ZeroTangent(), NoTangent())) - @constinferred pb((randn(real(T), howmany), ZeroTangent(), ZeroTangent(), - NoTangent())) - @constinferred pb((randn(real(T), howmany), [randn(T, n)], ZeroTangent(), - NoTangent())) - @constinferred pb((randn(real(T), howmany), [randn(T, n) for _ in 1:howmany], - [randn(T, 2 * n) for _ in 1:howmany], NoTangent())) - end - for alg_rrule in (alg_rrule1, alg_rrule2) - (mat_example_mat, mat_example_ftuple, mat_example_fval, mat_example_fd, - Avec, xvec, vals, lvecs, rvecs) = build_mat_example(A, x, howmany, alg, alg_rrule) - - (JA, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, xvec) - (JA1, Jx1) = Zygote.jacobian(mat_example_mat, Avec, xvec) - (JA2, Jx2) = Zygote.jacobian(mat_example_fval, Avec, xvec) - (JA3, Jx3) = Zygote.jacobian(mat_example_ftuple, Avec, xvec) - - # finite difference comparison using some kind of tolerance heuristic - @test isapprox(JA, JA1; rtol=3 * 
n * n * condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) - @test all(isapprox.(JA1, JA3; atol=n * eps(real(T)))) - @test norm(Jx, Inf) < 5 * condA * sqrt(eps(real(T))) - @test all(iszero, Jx1) - @test all(iszero, Jx2) - @test all(iszero, Jx3) - - # some analysis - if eltype(A) <: Complex # test holomorphicity / Cauchy-Riemann equations - ∂vals = complex.(JA1[1:howmany, :], - JA1[howmany * (3 * n + 1) .+ (1:howmany), :]) - ∂lvecs = map(1:howmany) do i - return complex.(JA1[(howmany + (i - 1) * n) .+ (1:n), :], - JA1[(howmany * (3 * n + 2) + (i - 1) * n) .+ (1:n), :]) - end - ∂rvecs = map(1:howmany) do i - return complex.(JA1[(howmany * (n + 1) + (i - 1) * (2 * n)) .+ (1:(2n)), :], - JA1[(howmany * (4 * n + 2) + (i - 1) * 2n) .+ (1:(2n)), :]) - end - else - ∂vals = JA1[1:howmany, :] - ∂lvecs = map(1:howmany) do i - return JA1[(howmany + (i - 1) * n) .+ (1:n), :] - end - ∂rvecs = map(1:howmany) do i - return JA1[(howmany * (n + 1) + (i - 1) * (2 * n)) .+ (1:(2n)), :] - end - end - # test orthogonality of vecs and ∂vecs - for i in 1:howmany - prec = 4 * cond(A) * sqrt(eps(real(T))) - @test all(<(prec), real.(lvecs[i]' * ∂lvecs[i])) - @test all(<(prec), real.(rvecs[i]' * ∂rvecs[i])) - @test all(<(prec), abs.(lvecs[i]' * ∂lvecs[i] + rvecs[i]' * ∂rvecs[i])) - end - end - if T <: Complex - @testset "test warnings and info" begin - alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(), - NoTangent())) - - alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=0) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs (:warn,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) pb((ZeroTangent(), 
lvecs[2:-1:1], - im .* rvecs[1:2] .+ rvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ lvecs[2:-1:1], - +im .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - @test_logs pb((ZeroTangent(), (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], - (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - - alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=1) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs (:warn,) (:info,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:info,) pb((ZeroTangent(), lvecs[2:-1:1], - im .* rvecs[1:2] .+ rvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:info,) pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:info,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ lvecs[2:-1:1], - +im .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - @test_logs (:info,) pb((ZeroTangent(), (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], - (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - - alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=-1) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(), - NoTangent())) - - alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=0) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs (:warn,) (:warn,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ - lvecs[2:-1:1], ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:warn,) pb((ZeroTangent(), lvecs[2:-1:1], - im .* rvecs[1:2] .+ - rvecs[2:-1:1], ZeroTangent(), - NoTangent())) - @test_logs 
pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:warn,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ - lvecs[2:-1:1], - +im .* rvecs[1:2] + - rvecs[2:-1:1], - NoTangent())) - @test_logs pb((ZeroTangent(), - (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], - (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - - alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=1) - (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, - howmany, :LR, alg; - alg_rrule=alg_rrule) - @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ - lvecs[2:-1:1], ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), lvecs[2:-1:1], - im .* rvecs[1:2] .+ - rvecs[2:-1:1], ZeroTangent(), - NoTangent())) - @test_logs (:info,) (:info,) pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], - ZeroTangent(), - NoTangent())) - @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), - im .* lvecs[1:2] .+ - lvecs[2:-1:1], - +im .* rvecs[1:2] + - rvecs[2:-1:1], - NoTangent())) - @test_logs (:info,) (:info,) pb((ZeroTangent(), - (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], - (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], - NoTangent())) - end - end -end -@timedtestset "Large svdsolve AD test with eltype=$T" for T in (Float64, ComplexF64) - which = :LR - A = rand(T, (N, N + n)) .- one(T) / 2 - A = I[1:N, 1:(N + n)] - (9 // 10) * A / maximum(svdvals(A)) - x = 2 * (rand(T, N) .- one(T) / 2) - x /= norm(x) - c = 2 * (rand(T, N) .- one(T) / 2) - d = 2 * (rand(T, N + n) .- one(T) / 2) - - howmany = 2 - tol = 2 * N^2 * eps(real(T)) - alg = GKL(; tol=tol, krylovdim=2n) - alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) - alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) - for alg_rrule in (alg_rrule1, alg_rrule2) - fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs = build_fun_example(A, - x, - c, - d, - howmany, - alg, - 
alg_rrule) - - (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, - cvec, dvec) - (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example_ad, Avec, xvec, cvec, dvec) - @test JA ≈ JA′ - @test Jc ≈ Jc′ - @test Jd ≈ Jd′ - @test norm(Jx, Inf) < (T <: Complex ? 4n : n) * sqrt(eps(real(T))) - end -end -end diff --git a/test/ad/degenerateeigsolve.jl b/test/ad/degenerateeigsolve.jl new file mode 100644 index 0000000..af68c7c --- /dev/null +++ b/test/ad/degenerateeigsolve.jl @@ -0,0 +1,169 @@ +module DegenerateEigsolveAD + +using KrylovKit, LinearAlgebra +using Random, Test, TestExtras +using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences +Random.seed!(987654321) + +fdm = ChainRulesTestUtils._fdm +n = 10 +N = 30 + +function build_mat_example(A, B, C, x, alg, alg_rrule) + howmany = 1 + which = :LM + + Avec, A_fromvec = to_vec(A) + Bvec, B_fromvec = to_vec(B) + Cvec, C_fromvec = to_vec(C) + xvec, x_fromvec = to_vec(x) + + M = [zero(A) zero(A) C; A zero(A) zero(A); zero(A) B zero(A)] + vals, vecs, info = eigsolve(M, x, howmany, which, alg) + info.converged < howmany && @warn "eigsolve did not converge" + + function mat_example(Av, Bv, Cv, xv) + à = A_fromvec(Av) + B̃ = B_fromvec(Bv) + C̃ = C_fromvec(Cv) + x̃ = x_fromvec(xv) + M̃ = [zero(Ã) zero(Ã) C̃; à zero(Ã) zero(Ã); zero(Ã) B̃ zero(Ã)] + vals′, vecs′, info′ = eigsolve(M̃, x̃, howmany, which, alg; alg_rrule=alg_rrule) + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fun(Av, Bv, Cv, xv) + à = A_fromvec(Av) + B̃ = B_fromvec(Bv) + C̃ = C_fromvec(Cv) + x̃ = x_fromvec(xv) + M̃ = [zero(Ã) zero(Ã) C̃; à zero(Ã) zero(Ã); zero(Ã) B̃ zero(Ã)] + f = x -> M̃ * x + vals′, vecs′, info′ = eigsolve(f, x̃, howmany, which, alg; alg_rrule=alg_rrule) + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fd(Av, Bv, Cv, xv) + à = A_fromvec(Av) + B̃ = B_fromvec(Bv) + C̃ = C_fromvec(Cv) + x̃ = x_fromvec(xv) + M̃ = [zero(Ã) zero(Ã) C̃; à zero(Ã) zero(Ã); zero(Ã) B̃ zero(Ã)] + howmany′ = (eltype(Av) <: Complex ? 3 : 6) * howmany + vals′, vecs′, info′ = eigsolve(M̃, x̃, howmany′, which, alg; alg_rrule=alg_rrule) + _, i = findmin(abs.(vals′ .- vals[1])) + info′.converged < i && @warn "eigsolve did not converge" + d = dot(vecs[1], vecs′[i]) + @assert abs(d) > sqrt(eps(real(eltype(A)))) + phasefix = abs(d) / d + vecs′[i] = vecs′[i] * phasefix + catresults = vcat(vals′[i:i], vecs′[i:i]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return mat_example, mat_example_fun, mat_example_fd, Avec, Bvec, Cvec, xvec, vals, + vecs +end + +@timedtestset "Degenerate eigsolve AD test with eltype=$T" for T in (Float64, ComplexF64) + n = 10 + N = 3n + + A = randn(T, n, n) + B = randn(T, n, n) + C = randn(T, n, n) + + M = [zeros(T, n, 2n) A; B zeros(T, n, 2n); zeros(T, n, n) C zeros(T, n, n)] + x = randn(T, N) + + tol = 2 * N^2 * eps(real(T)) + alg = Arnoldi(; tol=tol, krylovdim=2n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) + alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) + mat_example1, mat_example_fun1, mat_example_fd, Avec, Bvec, Cvec, xvec, vals, vecs = build_mat_example(A, + B, + C, + x, + alg, + alg_rrule1) + mat_example2, mat_example_fun2, mat_example_fd, Avec, Bvec, Cvec, xvec, vals, vecs = build_mat_example(A, + B, + C, + x, + alg, + alg_rrule2) + (JA, JB, JC, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, Bvec, + Cvec, xvec) + (JA1, JB1, JC1, Jx1) = Zygote.jacobian(mat_example1, Avec, Bvec, Cvec, xvec) + (JA2, JB2, JC2, Jx2) = Zygote.jacobian(mat_example_fun1, Avec, Bvec, Cvec, xvec) + (JA3, JB3, JC3, Jx3) = Zygote.jacobian(mat_example2, Avec, Bvec, Cvec, xvec) + (JA4, JB4, JC4, Jx4) = Zygote.jacobian(mat_example_fun2, Avec, Bvec, Cvec, xvec) + + @test isapprox(JA, JA1; rtol=N * sqrt(eps(real(T)))) + @test isapprox(JB, JB1; rtol=N * sqrt(eps(real(T)))) + @test isapprox(JC, JC1; rtol=N * sqrt(eps(real(T)))) + + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) + @test all(isapprox.(JB1, JB2; atol=n * eps(real(T)))) + @test all(isapprox.(JC1, JC2; atol=n * eps(real(T)))) + + @test all(isapprox.(JA1, JA3; atol=tol)) + @test all(isapprox.(JB1, JB3; atol=tol)) + @test all(isapprox.(JC1, JC3; atol=tol)) + + @test all(isapprox.(JA1, JA4; atol=tol)) + @test all(isapprox.(JB1, JB4; atol=tol)) + @test all(isapprox.(JC1, JC4; atol=tol)) 
+ + @test norm(Jx, Inf) < N * sqrt(eps(real(T))) + @test all(iszero, Jx1) + @test all(iszero, Jx2) + @test all(iszero, Jx3) + @test all(iszero, Jx4) + + # some analysis + ∂valsA = complex.(JA1[1, :], JA1[N + 2, :]) + ∂valsB = complex.(JB1[1, :], JB1[N + 2, :]) + ∂valsC = complex.(JC1[1, :], JC1[N + 2, :]) + ∂vecsA = complex.(JA1[1 .+ (1:N), :], JA1[N + 2 .+ (1:N), :]) + ∂vecsB = complex.(JB1[1 .+ (1:N), :], JB1[N + 2 .+ (1:N), :]) + ∂vecsC = complex.(JC1[1 .+ (1:N), :], JC1[N + 2 .+ (1:N), :]) + if T <: Complex # test holomorphicity / Cauchy-Riemann equations + # for eigenvalues + @test real(∂valsA[1:2:(2n^2)]) ≈ +imag(∂valsA[2:2:(2n^2)]) + @test imag(∂valsA[1:2:(2n^2)]) ≈ -real(∂valsA[2:2:(2n^2)]) + @test real(∂valsB[1:2:(2n^2)]) ≈ +imag(∂valsB[2:2:(2n^2)]) + @test imag(∂valsB[1:2:(2n^2)]) ≈ -real(∂valsB[2:2:(2n^2)]) + @test real(∂valsC[1:2:(2n^2)]) ≈ +imag(∂valsC[2:2:(2n^2)]) + @test imag(∂valsC[1:2:(2n^2)]) ≈ -real(∂valsC[2:2:(2n^2)]) + # and for eigenvectors + @test real(∂vecsA[:, 1:2:(2n^2)]) ≈ +imag(∂vecsA[:, 2:2:(2n^2)]) + @test imag(∂vecsA[:, 1:2:(2n^2)]) ≈ -real(∂vecsA[:, 2:2:(2n^2)]) + @test real(∂vecsB[:, 1:2:(2n^2)]) ≈ +imag(∂vecsB[:, 2:2:(2n^2)]) + @test imag(∂vecsB[:, 1:2:(2n^2)]) ≈ -real(∂vecsB[:, 2:2:(2n^2)]) + @test real(∂vecsC[:, 1:2:(2n^2)]) ≈ +imag(∂vecsC[:, 2:2:(2n^2)]) + @test imag(∂vecsC[:, 1:2:(2n^2)]) ≈ -real(∂vecsC[:, 2:2:(2n^2)]) + end + # test orthogonality of vecs and ∂vecs + @test all(isapprox.(abs.(vecs[1]' * ∂vecsA), 0; atol=sqrt(eps(real(T))))) + @test all(isapprox.(abs.(vecs[1]' * ∂vecsB), 0; atol=sqrt(eps(real(T))))) + @test all(isapprox.(abs.(vecs[1]' * ∂vecsC), 0; atol=sqrt(eps(real(T))))) +end + +end diff --git a/test/ad/eigsolve.jl b/test/ad/eigsolve.jl new file mode 100644 index 0000000..e27e192 --- /dev/null +++ b/test/ad/eigsolve.jl @@ -0,0 +1,391 @@ +module EigsolveAD +using KrylovKit, LinearAlgebra +using Random, Test, TestExtras +using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences 
+Random.seed!(987654321) + +fdm = ChainRulesTestUtils._fdm +n = 10 +N = 30 + +function build_mat_example(A, x, howmany::Int, which, alg, alg_rrule) + Avec, A_fromvec = to_vec(A) + xvec, x_fromvec = to_vec(x) + + vals, vecs, info = eigsolve(A, x, howmany, which, alg) + info.converged < howmany && @warn "eigsolve did not converge" + if eltype(A) <: Real && length(vals) > howmany && + vals[howmany] == conj(vals[howmany + 1]) + howmany += 1 + end + + function mat_example(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + vals′, vecs′, info′ = eigsolve(Ã, x̃, howmany, which, alg; alg_rrule=alg_rrule) + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fun(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + f = x -> à * x + vals′, vecs′, info′ = eigsolve(f, x̃, howmany, which, alg; alg_rrule=alg_rrule) + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fd(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + vals′, vecs′, info′ = eigsolve(Ã, x̃, howmany, which, alg; alg_rrule=alg_rrule) + info′.converged < howmany && @warn "eigsolve did not converge" + for i in 1:howmany + d = dot(vecs[i], vecs′[i]) + @assert abs(d) > sqrt(eps(real(eltype(A)))) + phasefix = abs(d) / d + vecs′[i] = vecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return mat_example, mat_example_fun, mat_example_fd, Avec, xvec, vals, vecs, howmany +end + +function build_fun_example(A, x, c, d, howmany::Int, which, alg, alg_rrule) + Avec, matfromvec = to_vec(A) + xvec, vecfromvec = to_vec(x) + cvec, = to_vec(c) + dvec, = to_vec(d) + + vals, vecs, info = eigsolve(x, howmany, which, alg) do y + return A * y + c * dot(d, y) + end + info.converged < howmany && @warn "eigsolve did not converge" + if eltype(A) <: Real && length(vals) > howmany && + vals[howmany] == conj(vals[howmany + 1]) + howmany += 1 + end + + fun_example_ad = let howmany′ = howmany + function (Av, xv, cv, dv) + à = matfromvec(Av) + x̃ = vecfromvec(xv) + c̃ = vecfromvec(cv) + d̃ = vecfromvec(dv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany′, which, alg; + alg_rrule=alg_rrule) do y + return à * y + c̃ * dot(d̃, y) + end + info′.converged < howmany′ && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + end + + fun_example_fd = let howmany′ = howmany + function (Av, xv, cv, dv) + à = matfromvec(Av) + x̃ = vecfromvec(xv) + c̃ = vecfromvec(cv) + d̃ = vecfromvec(dv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany′, which, alg; + alg_rrule=alg_rrule) do y + return à * y + c̃ * dot(d̃, y) + end + info′.converged < howmany′ && @warn "eigsolve did not converge" + for i in 1:howmany′ + d = dot(vecs[i], vecs′[i]) + @assert abs(d) > sqrt(eps(real(eltype(A)))) + phasefix = abs(d) / d + vecs′[i] = vecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + end + + return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany +end + +function build_hermitianfun_example(A, x, c, howmany::Int, which, alg, alg_rrule) + Avec, matfromvec = to_vec(A) + xvec, xvecfromvec = to_vec(x) + cvec, cvecfromvec = to_vec(c) + + vals, vecs, info = eigsolve(x, howmany, which, alg) do y + return Hermitian(A) * y + c * dot(c, y) + end + info.converged < howmany && @warn "eigsolve did not converge" + + function fun_example(Av, xv, cv) + à = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg; + alg_rrule=alg_rrule) do y + return Hermitian(Ã) * y + c̃ * dot(c̃, y) + end + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function fun_example_fd(Av, xv, cv) + à = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg; + alg_rrule=alg_rrule) do y + return Hermitian(Ã) * y + c̃ * dot(c̃, y) + end + info′.converged < howmany && @warn "eigsolve did not converge" + for i in 1:howmany + d = dot(vecs[i], vecs′[i]) + @assert abs(d) > sqrt(eps(real(eltype(A)))) + phasefix = abs(d) / d + vecs′[i] = vecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany +end + +@timedtestset "Small eigsolve AD test for eltype=$T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + if T <: Complex + whichlist = (:LM, :SR, :LR, :SI, :LI) + else + whichlist = (:LM, :SR, :LR) + end + A = 2 * (rand(T, (n, n)) .- one(T) / 2) + x = 2 * (rand(T, n) .- one(T) / 2) + x /= norm(x) + + howmany = 3 + condA = cond(A) + tol = n * condA * (T <: Real ? eps(T) : 4 * eps(real(T))) + alg = Arnoldi(; tol=tol, krylovdim=n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) + alg_rrule2 = GMRES(; tol=tol, krylovdim=n + 1, verbosity=-1) + config = Zygote.ZygoteRuleConfig() + @testset for which in whichlist + for alg_rrule in (alg_rrule1, alg_rrule2) + # unfortunately, rrule does not seem type stable for function arguments, because the + # `rrule_via_ad` call does not produce type stable `rrule`s for the function + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + which, alg; alg_rrule=alg_rrule) + # NOTE: the following is not necessary here, as it is corrected for in the `eigsolve` rrule + # if length(vals) > howmany && vals[howmany] == conj(vals[howmany + 1]) + # howmany += 1 + # end + @constinferred pb((ZeroTangent(), ZeroTangent(), NoTangent())) + @constinferred pb((randn(T, howmany), ZeroTangent(), NoTangent())) + @constinferred pb((randn(T, howmany), [randn(T, n)], NoTangent())) + @constinferred pb((randn(T, howmany), [randn(T, n) for _ in 1:howmany], + NoTangent())) + end + + for alg_rrule in (alg_rrule1, alg_rrule2) + mat_example, mat_example_fun, mat_example_fd, Avec, xvec, vals, vecs, howmany = build_mat_example(A, + x, + howmany, + which, + alg, + alg_rrule) + + (JA, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, xvec) + (JA1, Jx1) = Zygote.jacobian(mat_example, Avec, xvec) + (JA2, Jx2) 
= Zygote.jacobian(mat_example_fun, Avec, xvec) + + # finite difference comparison using some kind of tolerance heuristic + @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) + @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) + @test all(iszero, Jx1) + @test all(iszero, Jx2) + + # some analysis + ∂vals = complex.(JA1[1:howmany, :], JA1[howmany * (n + 1) .+ (1:howmany), :]) + ∂vecs = map(1:howmany) do i + return complex.(JA1[(howmany + (i - 1) * n) .+ (1:n), :], + JA1[(howmany * (n + 2) + (i - 1) * n) .+ (1:n), :]) + end + if eltype(A) <: Complex # test holomorphicity / Cauchy-Riemann equations + # for eigenvalues + @test real(∂vals[:, 1:2:(2n^2)]) ≈ +imag(∂vals[:, 2:2:(2n^2)]) + @test imag(∂vals[:, 1:2:(2n^2)]) ≈ -real(∂vals[:, 2:2:(2n^2)]) + # and for eigenvectors + for i in 1:howmany + @test real(∂vecs[i][:, 1:2:(2n^2)]) ≈ +imag(∂vecs[i][:, 2:2:(2n^2)]) + @test imag(∂vecs[i][:, 1:2:(2n^2)]) ≈ -real(∂vecs[i][:, 2:2:(2n^2)]) + end + end + # test orthogonality of vecs and ∂vecs + for i in 1:howmany + @test all(isapprox.(abs.(vecs[i]' * ∂vecs[i]), 0; atol=sqrt(eps(real(T))))) + end + end + end + + if T <: Complex + @testset "test warnings and info" begin + alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=-1) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent())) + + alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=0) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs (:warn,) pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], + NoTangent())) + pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent())) + @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) + + alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=1) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, 
eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs (:warn,) (:info,) pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], + NoTangent())) + pbs = @test_logs (:info,) pb((ZeroTangent(), vecs[1:2], NoTangent())) + @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) + + alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=-1) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent())) + + alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=0) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs (:warn,) (:warn,) pb((ZeroTangent(), + im .* vecs[1:2] .+ + vecs[2:-1:1], + NoTangent())) + pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent())) + @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) + + alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=1) + (vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany, + :LR, alg; alg_rrule=alg_rrule) + @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), + im .* vecs[1:2] .+ + vecs[2:-1:1], + NoTangent())) + pbs = @test_logs (:info,) (:info,) pb((ZeroTangent(), vecs[1:2], NoTangent())) + @test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T))) + end + end +end +@timedtestset "Large eigsolve AD test with eltype=$T" for T in (Float64, ComplexF64) + if T <: Complex + whichlist = (:LM, :SI) + else + whichlist = (:LM, :SR) + end + @testset for which in whichlist + A = rand(T, (N, N)) .- one(T) / 2 + A = I - (9 // 10) * A / maximum(abs, eigvals(A)) + x = 2 * (rand(T, N) .- one(T) / 2) + x /= norm(x) + c = 2 * (rand(T, N) .- one(T) / 2) + d = 2 * (rand(T, N) .- one(T) / 2) + + howmany = 2 + tol = 2 * N^2 * eps(real(T)) + alg = Arnoldi(; tol=tol, krylovdim=2n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) + alg_rrule2 = GMRES(; 
tol=tol, krylovdim=2n, verbosity=-1) + @testset for alg_rrule in (alg_rrule1, alg_rrule2) + fun_example, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A, + x, + c, + d, + howmany, + which, + alg, + alg_rrule) + + (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec, dvec) + (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example, Avec, xvec, cvec, dvec) + @test JA ≈ JA′ + @test Jc ≈ Jc′ + @test Jd ≈ Jd′ + end + end +end +@timedtestset "Large Hermitian eigsolve AD test with eltype=$T" for T in + (Float64, ComplexF64) + whichlist = (:LR, :SR) + @testset for which in whichlist + A = rand(T, (N, N)) .- one(T) / 2 + A = I - (9 // 10) * A / maximum(abs, eigvals(A)) + x = 2 * (rand(T, N) .- one(T) / 2) + x /= norm(x) + c = 2 * (rand(T, N) .- one(T) / 2) + + howmany = 2 + tol = 2 * N^2 * eps(real(T)) + alg = Lanczos(; tol=tol, krylovdim=2n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) + alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) + @testset for alg_rrule in (alg_rrule1, alg_rrule2) + fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany = build_hermitianfun_example(A, + x, + c, + howmany, + which, + alg, + alg_rrule) + + (JA, Jx, Jc) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec) + (JA′, Jx′, Jc′) = Zygote.jacobian(fun_example, Avec, xvec, cvec) + @test JA ≈ JA′ + @test Jc ≈ Jc′ + end + end +end + +end diff --git a/test/ad/linsolve.jl b/test/ad/linsolve.jl new file mode 100644 index 0000000..b3220cd --- /dev/null +++ b/test/ad/linsolve.jl @@ -0,0 +1,129 @@ +module LinsolveAD +using KrylovKit, LinearAlgebra +using Random, Test, TestExtras +using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences + +fdm = ChainRulesTestUtils._fdm +n = 10 +N = 30 + +function build_mat_example(A, b, x, alg, alg_rrule) + Avec, A_fromvec = to_vec(A) + bvec, b_fromvec = to_vec(b) + xvec, x_fromvec = to_vec(x) + T = eltype(A) + + function mat_example(Av, bv, 
xv) + à = A_fromvec(Av) + b̃ = b_fromvec(bv) + x̃ = x_fromvec(xv) + x, info = linsolve(Ã, b̃, x̃, alg; alg_rrule=alg_rrule) + if info.converged == 0 + @warn "linsolve did not converge:" + println("normres = ", info.normres) + end + xv, = to_vec(x) + return xv + end + function mat_example_fun(Av, bv, xv) + à = A_fromvec(Av) + b̃ = b_fromvec(bv) + x̃ = x_fromvec(xv) + f = x -> à * x + x, info = linsolve(f, b̃, x̃, alg; alg_rrule=alg_rrule) + if info.converged == 0 + @warn "linsolve did not converge:" + println("normres = ", info.normres) + end + xv, = to_vec(x) + return xv + end + return mat_example, mat_example_fun, Avec, bvec, xvec +end + +function build_fun_example(A, b, c, d, e, f, alg, alg_rrule) + Avec, matfromvec = to_vec(A) + bvec, vecfromvec = to_vec(b) + cvec, = to_vec(c) + dvec, = to_vec(d) + evec, scalarfromvec = to_vec(e) + fvec, = to_vec(f) + + function fun_example(Av, bv, cv, dv, ev, fv) + à = matfromvec(Av) + b̃ = vecfromvec(bv) + c̃ = vecfromvec(cv) + d̃ = vecfromvec(dv) + ẽ = scalarfromvec(ev) + f̃ = scalarfromvec(fv) + + x, info = linsolve(b̃, zero(b̃), alg, ẽ, f̃; alg_rrule=alg_rrule) do y + return à * y + c̃ * dot(d̃, y) + end + # info.converged > 0 || @warn "not converged" + xv, = to_vec(x) + return xv + end + return fun_example, Avec, bvec, cvec, dvec, evec, fvec +end + +@testset "Small linsolve AD test with eltype=$T" for T in (Float32, Float64, ComplexF32, + ComplexF64) + A = 2 * (rand(T, (n, n)) .- one(T) / 2) + b = 2 * (rand(T, n) .- one(T) / 2) + b /= norm(b) + x = 2 * (rand(T, n) .- one(T) / 2) + + condA = cond(A) + tol = condA * (T <: Real ? 
eps(T) : 4 * eps(real(T))) + alg = GMRES(; tol=tol, krylovdim=n, maxiter=1) + + config = Zygote.ZygoteRuleConfig() + _, pb = ChainRulesCore.rrule(config, linsolve, A, b, x, alg, 0, 1; alg_rrule=alg) + @constinferred pb((ZeroTangent(), NoTangent())) + @constinferred pb((rand(T, n), NoTangent())) + + mat_example, mat_example_fun, Avec, bvec, xvec = build_mat_example(A, b, x, alg, alg) + (JA, Jb, Jx) = FiniteDifferences.jacobian(fdm, mat_example, Avec, bvec, xvec) + (JA1, Jb1, Jx1) = Zygote.jacobian(mat_example, Avec, bvec, xvec) + (JA2, Jb2, Jx2) = Zygote.jacobian(mat_example_fun, Avec, bvec, xvec) + + @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) + # factor 2 is minimally necessary for complex case, but 3 is more robust + @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) + @test all(iszero, Jx1) +end + +@testset "Large linsolve AD test with eltype=$T" for T in (Float64, ComplexF64) + A = rand(T, (N, N)) .- one(T) / 2 + A = I - (9 // 10) * A / maximum(abs, eigvals(A)) + b = 2 * (rand(T, N) .- one(T) / 2) + c = 2 * (rand(T, N) .- one(T) / 2) + d = 2 * (rand(T, N) .- one(T) / 2) + e = rand(T) + f = rand(T) + + # mix algorithms] + tol = N^2 * eps(real(T)) + alg1 = GMRES(; tol=tol, krylovdim=20) + alg2 = BiCGStab(; tol=tol, maxiter=100) # BiCGStab seems to require slightly smaller tolerance for tests to work + for (alg, alg_rrule) in ((alg1, alg2), (alg2, alg1)) + fun_example, Avec, bvec, cvec, dvec, evec, fvec = build_fun_example(A, b, c, d, e, + f, alg, + alg_rrule) + + (JA, Jb, Jc, Jd, Je, Jf) = FiniteDifferences.jacobian(fdm, fun_example, + Avec, bvec, cvec, dvec, evec, + fvec) + (JA′, Jb′, Jc′, Jd′, Je′, Jf′) = Zygote.jacobian(fun_example, Avec, bvec, cvec, + dvec, evec, fvec) + @test JA ≈ JA′ + @test Jb ≈ Jb′ + @test Jc ≈ Jc′ + @test Jd ≈ Jd′ + @test Je ≈ Je′ + @test Jf ≈ Jf′ + end +end +end diff --git a/test/ad/svdsolve.jl b/test/ad/svdsolve.jl new file mode 100644 index 0000000..08d6366 --- 
/dev/null +++ b/test/ad/svdsolve.jl @@ -0,0 +1,368 @@ +module SvdsolveAD +using KrylovKit, LinearAlgebra +using Random, Test, TestExtras +using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences +Random.seed!(123456789) + +fdm = ChainRulesTestUtils._fdm +n = 10 +N = 30 + +function build_mat_example(A, x, howmany::Int, alg, alg_rrule) + Avec, A_fromvec = to_vec(A) + xvec, x_fromvec = to_vec(x) + + vals, lvecs, rvecs, info = svdsolve(A, x, howmany, :LR, alg) + info.converged < howmany && @warn "svdsolve did not converge" + + function mat_example_mat(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + vals′, lvecs′, rvecs′, info′ = svdsolve(Ã, x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + function mat_example_fval(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + f = (x, adj::Val) -> (adj isa Val{true}) ? adjoint(Ã) * x : à * x + vals′, lvecs′, rvecs′, info′ = svdsolve(f, x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + function mat_example_ftuple(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + (f, fᴴ) = (x -> à * x, x -> adjoint(Ã) * x) + vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fd(Av, xv) + à = A_fromvec(Av) + x̃ = x_fromvec(xv) + vals′, lvecs′, rvecs′, info′ = svdsolve(Ã, x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + for i in 1:howmany + dl = dot(lvecs[i], lvecs′[i]) + dr = dot(rvecs[i], rvecs′[i]) + @assert abs(dl) > sqrt(eps(real(eltype(A)))) + @assert abs(dr) > sqrt(eps(real(eltype(A)))) + phasefix = sqrt(abs(dl * dr) / (dl * dr)) + lvecs′[i] = lvecs′[i] * phasefix + rvecs′[i] = rvecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return mat_example_mat, mat_example_ftuple, mat_example_fval, mat_example_fd, Avec, + xvec, vals, lvecs, rvecs +end + +function build_fun_example(A, x, c, d, howmany::Int, alg, alg_rrule) + Avec, matfromvec = to_vec(A) + xvec, xvecfromvec = to_vec(x) + cvec, cvecfromvec = to_vec(c) + dvec, dvecfromvec = to_vec(d) + + f = y -> A * y + c * dot(d, y) + fᴴ = y -> adjoint(A) * y + d * dot(c, y) + vals, lvecs, rvecs, info = svdsolve((f, fᴴ), x, howmany, :LR, alg) + info.converged < howmany && @warn "svdsolve did not converge" + + function fun_example_ad(Av, xv, cv, dv) + à = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + d̃ = dvecfromvec(dv) + + f = y -> à * y + c̃ * dot(d̃, y) + fᴴ = y -> adjoint(Ã) * y + d̃ * dot(c̃, y) + vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + function fun_example_fd(Av, xv, cv, dv) + à = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + d̃ = dvecfromvec(dv) + + f = y -> à * y + c̃ * dot(d̃, y) + fᴴ = y -> adjoint(Ã) * y + d̃ * dot(c̃, y) + vals′, lvecs′, rvecs′, info′ = svdsolve((f, fᴴ), x̃, howmany, :LR, alg; + alg_rrule=alg_rrule) + info′.converged < howmany && @warn "svdsolve did not converge" + for i in 1:howmany + dl = dot(lvecs[i], lvecs′[i]) + dr = dot(rvecs[i], rvecs′[i]) + @assert abs(dl) > sqrt(eps(real(eltype(A)))) + @assert abs(dr) > sqrt(eps(real(eltype(A)))) + phasefix = sqrt(abs(dl * dr) / (dl * dr)) + lvecs′[i] = lvecs′[i] * phasefix + rvecs′[i] = rvecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany], lvecs′[1:howmany]..., rvecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs +end + +@timedtestset "Small svdsolve AD test with eltype=$T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + A = 2 * (rand(T, (n, 2 * n)) .- one(T) / 2) + x = 2 * (rand(T, n) .- one(T) / 2) + x /= norm(x) + condA = cond(A) + + howmany = 3 + tol = 3 * n * condA * (T <: Real ? 
eps(T) : 4 * eps(real(T))) + alg = GKL(; krylovdim=2n, tol=tol) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1) + alg_rrule2 = GMRES(; tol=tol, krylovdim=3n, verbosity=-1) + config = Zygote.ZygoteRuleConfig() + for alg_rrule in (alg_rrule1, alg_rrule2) + # unfortunately, rrule does not seem type stable for function arguments, because the + # `rrule_via_ad` call does not produce type stable `rrule`s for the function + _, pb = ChainRulesCore.rrule(config, svdsolve, A, x, howmany, :LR, alg; + alg_rrule=alg_rrule) + @constinferred pb((ZeroTangent(), ZeroTangent(), ZeroTangent(), NoTangent())) + @constinferred pb((randn(real(T), howmany), ZeroTangent(), ZeroTangent(), + NoTangent())) + @constinferred pb((randn(real(T), howmany), [randn(T, n)], ZeroTangent(), + NoTangent())) + @constinferred pb((randn(real(T), howmany), [randn(T, n) for _ in 1:howmany], + [randn(T, 2 * n) for _ in 1:howmany], NoTangent())) + end + for alg_rrule in (alg_rrule1, alg_rrule2) + (mat_example_mat, mat_example_ftuple, mat_example_fval, mat_example_fd, + Avec, xvec, vals, lvecs, rvecs) = build_mat_example(A, x, howmany, alg, alg_rrule) + + (JA, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, xvec) + (JA1, Jx1) = Zygote.jacobian(mat_example_mat, Avec, xvec) + (JA2, Jx2) = Zygote.jacobian(mat_example_fval, Avec, xvec) + (JA3, Jx3) = Zygote.jacobian(mat_example_ftuple, Avec, xvec) + + # finite difference comparison using some kind of tolerance heuristic + @test isapprox(JA, JA1; rtol=3 * n * n * condA * sqrt(eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) + @test all(isapprox.(JA1, JA3; atol=n * eps(real(T)))) + @test norm(Jx, Inf) < 5 * condA * sqrt(eps(real(T))) + @test all(iszero, Jx1) + @test all(iszero, Jx2) + @test all(iszero, Jx3) + + # some analysis + if eltype(A) <: Complex # test holomorphicity / Cauchy-Riemann equations + ∂vals = complex.(JA1[1:howmany, :], + JA1[howmany * (3 * n + 1) .+ (1:howmany), :]) + ∂lvecs = map(1:howmany) do i + 
return complex.(JA1[(howmany + (i - 1) * n) .+ (1:n), :], + JA1[(howmany * (3 * n + 2) + (i - 1) * n) .+ (1:n), :]) + end + ∂rvecs = map(1:howmany) do i + return complex.(JA1[(howmany * (n + 1) + (i - 1) * (2 * n)) .+ (1:(2n)), :], + JA1[(howmany * (4 * n + 2) + (i - 1) * 2n) .+ (1:(2n)), :]) + end + else + ∂vals = JA1[1:howmany, :] + ∂lvecs = map(1:howmany) do i + return JA1[(howmany + (i - 1) * n) .+ (1:n), :] + end + ∂rvecs = map(1:howmany) do i + return JA1[(howmany * (n + 1) + (i - 1) * (2 * n)) .+ (1:(2n)), :] + end + end + # test orthogonality of vecs and ∂vecs + for i in 1:howmany + prec = 4 * cond(A) * sqrt(eps(real(T))) + @test all(<(prec), real.(lvecs[i]' * ∂lvecs[i])) + @test all(<(prec), real.(rvecs[i]' * ∂rvecs[i])) + @test all(<(prec), abs.(lvecs[i]' * ∂lvecs[i] + rvecs[i]' * ∂rvecs[i])) + end + end + if T <: Complex + @testset "test warnings and info" begin + alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(), + NoTangent())) + + alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=0) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs (:warn,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) pb((ZeroTangent(), lvecs[2:-1:1], + im .* rvecs[1:2] .+ rvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ lvecs[2:-1:1], + +im .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + @test_logs pb((ZeroTangent(), (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], + (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + + alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, 
verbosity=1) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs (:warn,) (:info,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:info,) pb((ZeroTangent(), lvecs[2:-1:1], + im .* rvecs[1:2] .+ rvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:info,) pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:info,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ lvecs[2:-1:1], + +im .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + @test_logs (:info,) pb((ZeroTangent(), (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], + (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + + alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=-1) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(), + NoTangent())) + + alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=0) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs (:warn,) (:warn,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ + lvecs[2:-1:1], ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:warn,) pb((ZeroTangent(), lvecs[2:-1:1], + im .* rvecs[1:2] .+ + rvecs[2:-1:1], ZeroTangent(), + NoTangent())) + @test_logs pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:warn,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ + lvecs[2:-1:1], + +im .* rvecs[1:2] + + rvecs[2:-1:1], + NoTangent())) + @test_logs pb((ZeroTangent(), + (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], + (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + + alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=1) + (vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, 
svdsolve, A, x, + howmany, :LR, alg; + alg_rrule=alg_rrule) + @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ + lvecs[2:-1:1], ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), lvecs[2:-1:1], + im .* rvecs[1:2] .+ + rvecs[2:-1:1], ZeroTangent(), + NoTangent())) + @test_logs (:info,) (:info,) pb((ZeroTangent(), lvecs[1:2] .+ lvecs[2:-1:1], + ZeroTangent(), + NoTangent())) + @test_logs (:warn,) (:info,) (:warn,) (:info,) pb((ZeroTangent(), + im .* lvecs[1:2] .+ + lvecs[2:-1:1], + +im .* rvecs[1:2] + + rvecs[2:-1:1], + NoTangent())) + @test_logs (:info,) (:info,) pb((ZeroTangent(), + (1 + im) .* lvecs[1:2] .+ lvecs[2:-1:1], + (1 - im) .* rvecs[1:2] + rvecs[2:-1:1], + NoTangent())) + end + end +end +@timedtestset "Large svdsolve AD test with eltype=$T" for T in (Float64, ComplexF64) + which = :LR + A = rand(T, (N, N + n)) .- one(T) / 2 + A = I[1:N, 1:(N + n)] - (9 // 10) * A / maximum(svdvals(A)) + x = 2 * (rand(T, N) .- one(T) / 2) + x /= norm(x) + c = 2 * (rand(T, N) .- one(T) / 2) + d = 2 * (rand(T, N + n) .- one(T) / 2) + + howmany = 2 + tol = 2 * N^2 * eps(real(T)) + alg = GKL(; tol=tol, krylovdim=2n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1) + alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1) + for alg_rrule in (alg_rrule1, alg_rrule2) + fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs = build_fun_example(A, + x, + c, + d, + howmany, + alg, + alg_rrule) + + (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec, dvec) + (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example_ad, Avec, xvec, cvec, dvec) + @test JA ≈ JA′ + @test Jc ≈ Jc′ + @test Jd ≈ Jd′ + @test norm(Jx, Inf) < (T <: Complex ? 
4n : n) * sqrt(eps(real(T))) + end +end +end diff --git a/test/runtests.jl b/test/runtests.jl index aa402e5..ad26dce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,10 @@ include("expintegrator.jl") include("linalg.jl") include("nestedtuple.jl") -include("ad.jl") +include("ad/linsolve.jl") +include("ad/eigsolve.jl") +include("ad/degenerateeigsolve.jl") +include("ad/svdsolve.jl") t = time() - t println("Tests finished in $t seconds")