Skip to content

Commit

Permalink
fix bug in eigsolve rrule (#96)
Browse files Browse the repository at this point in the history
* some patches

* deal with a bug

* correct filename
  • Loading branch information
XingyuZhang2018 authored Nov 5, 2024
1 parent 8bccac8 commit fed1a10
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ include("utilities.jl")
include("linsolve.jl")
include("eigsolve.jl")
include("svdsolve.jl")
include("constructor.jl")

end # module
6 changes: 6 additions & 0 deletions ext/KrylovKitChainRulesCoreExt/constructor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
function ChainRulesCore.rrule(::Type{RecursiveVec}, A)
function RecursiveVec_pullback(ΔA)
return NoTangent(), ΔA.vecs
end
return RecursiveVec(A), RecursiveVec_pullback
end
19 changes: 11 additions & 8 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima
n_vals = isnothing(_n_vals) ? 0 : _n_vals
n_vecs = isnothing(_n_vecs) ? 0 : _n_vecs
n = max(n_vals, n_vecs)
if n < length(vals) && vals[n + 1] == conj(vals[n])
# this can probably only happen for real problems, where it would be problematic
# to split complex conjugate pairs in solving the tangent problem
n += 1
end
# special case (can this happen?): try to maintain type stability
if n == 0
if howmany == 0
Expand All @@ -65,11 +60,16 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima
∂f = construct∂f_eig(config, f, _vecs, ws)
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
else
ws = [vecs[1]]
ws = [zerovector(vecs[1])]
∂f = construct∂f_eig(config, f, vecs, ws)
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
end
end
if n < length(vals) && vals[n + 1] conj(vals[n])
# this can probably only happen for real problems, where it would be problematic
# to split complex conjugate pairs in solving the tangent problem
n += 1
end
Δvals = fill(zero(vals[1]), n)
if n_vals > 0
Δvals[1:n_vals] .= view(_Δvals, 1:n_vals)
Expand All @@ -80,10 +80,13 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima
else
Δvecs = fill(zerovector(vecs[1]), n)
if n_vecs > 0
Δvecs[1:n_vecs] .= view(_Δvecs, 1:n_vecs)
for i in 1:n_vecs
if !(_Δvecs[i] isa AbstractZero)
Δvecs[i] = _Δvecs[i]
end
end
end
end

# Compute actual pullback data:
#------------------------------
ws = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n),
Expand Down

0 comments on commit fed1a10

Please sign in to comment.