Skip to content

Commit

Permalink
fix arnoldi ad rule for degenerate eigsolve (#99)
Browse files Browse the repository at this point in the history
* fix arnoldi ad rule for degenerate eigsolve

* another fix attempt

* the problem is with finite difference

* cleanup
  • Loading branch information
Jutho authored Nov 9, 2024
1 parent e6b1fee commit e24dfe4
Show file tree
Hide file tree
Showing 7 changed files with 1,084 additions and 903 deletions.
35 changes: 23 additions & 12 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,15 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
# [(A * (1-P) + shift * P) -ΔV; 0 Λ], where eᵢ is the ith unit vector. We will need
# to renormalise the eigenvectors to have exactly eᵢ as second component. We use
# (0, e₁ + e₂ + ... + eₙ) as the initial guess for the eigenvalue problem.

W₀ = (zerovector(vecs[1]), one.(vals))
P = orthogonalprojector(vecs, n, Gc)
# TODO: is `realeigsolve` every used here, as there is a separate `alg_primal::Lanczos` method below
solver = (T <: Real) ? KrylovKit.realeigsolve : KrylovKit.eigsolve # for `eigsolve`, `T` will always be a Complex subtype`
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift
solver(W₀, n, reverse_which(which), alg_rrule) do (w, x)
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift,
eigsort = EigSorter(v -> minimum(DistanceTo(conj(v)), vals))

solver(W₀, n, eigsort, alg_rrule) do (w, x)
w₀ = P(w)
w′ = KrylovKit.apply(fᴴ, add(w, w₀, -1))
if !iszero(shift)
Expand All @@ -268,20 +272,25 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
tol = alg_rrule.tol
Q = orthogonalcomplementprojector(vecs, n, Gc)
for i in 1:n
w, x = Ws[i]
_, ic = findmax(abs, x)
factor = 1 / x[ic]
x[ic] = zero(x[ic])
d, ic = findmin(DistanceTo(conj(vals[i])), rvals)
w, x = Ws[ic]
factor = 1 / x[i]
x[i] = zero(x[i])
if alg_rrule.verbosity >= 0
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
error > 5 * tol &&
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
error = max(norm(x, Inf), abs(rvals[ic] - conj(vals[i])))
error > 10 * tol &&
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $error"
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
ws[i] = VectorInterface.add!!(zs[i], Q(w), -factor)
end
return ws
end

struct DistanceTo{T}
x::T
end
(d::DistanceTo)(y) = norm(y - d.x)

# several simplications happen in the case of a Hermitian eigenvalue problem
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ,
alg_primal::Lanczos, alg_rrule::Arnoldi)
Expand Down Expand Up @@ -342,8 +351,10 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
W₀ = (zerovector(vecs[1]), one.(vals))
P = orthogonalprojector(vecs, n)
solver = (T <: Real) ? KrylovKit.realeigsolve : KrylovKit.eigsolve
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift
solver(W₀, n, reverse_which(which), alg_rrule) do (w, x)
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift,
eigsort = EigSorter(v -> minimum(DistanceTo(conj(v)), vals))

solver(W₀, n, eigsort, alg_rrule) do (w, x)
w₀ = P(w)
w′ = KrylovKit.apply(fᴴ, add(w, w₀, -1))
if !iszero(shift)
Expand Down
Loading

0 comments on commit e24dfe4

Please sign in to comment.