diff --git a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl index bf31bf0..64a67b6 100644 --- a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl +++ b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl @@ -11,5 +11,6 @@ include("utilities.jl") include("linsolve.jl") include("eigsolve.jl") include("svdsolve.jl") +include("constructor.jl") end # module diff --git a/ext/KrylovKitChainRulesCoreExt/constructor.jl b/ext/KrylovKitChainRulesCoreExt/constructor.jl new file mode 100644 index 0000000..04c13ee --- /dev/null +++ b/ext/KrylovKitChainRulesCoreExt/constructor.jl @@ -0,0 +1,6 @@ +function ChainRulesCore.rrule(::Type{RecursiveVec}, A) + function RecursiveVec_pullback(ΔA) + return NoTangent(), ΔA.vecs + end + return RecursiveVec(A), RecursiveVec_pullback +end diff --git a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl index b597135..e1c3131 100644 --- a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl +++ b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl @@ -51,11 +51,6 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima n_vals = isnothing(_n_vals) ? 0 : _n_vals n_vecs = isnothing(_n_vecs) ? 0 : _n_vecs n = max(n_vals, n_vecs) - if n < length(vals) && vals[n + 1] == conj(vals[n]) - # this can probably only happen for real problems, where it would be problematic - # to split complex conjugate pairs in solving the tangent problem - n += 1 - end # special case (can this happen?): try to maintain type stability if n == 0 if howmany == 0 @@ -65,11 +60,16 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima ∂f = construct∂f_eig(config, f, _vecs, ws) return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg else - ws = [vecs[1]] + ws = [zerovector(vecs[1])] ∂f = construct∂f_eig(config, f, vecs, ws) return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg end end + if n < length(vals) && vals[n + 1] ≈ conj(vals[n]) + # this can probably only happen for real problems, where it would be problematic + # to split complex conjugate pairs in solving the tangent problem + n += 1 + end Δvals = fill(zero(vals[1]), n) if n_vals > 0 Δvals[1:n_vals] .= view(_Δvals, 1:n_vals) @@ -80,10 +80,13 @@ function make_eigsolve_pullback(config, f, fᴴ, x₀, howmany, which, alg_prima else Δvecs = fill(zerovector(vecs[1]), n) if n_vecs > 0 - Δvecs[1:n_vecs] .= view(_Δvecs, 1:n_vecs) + for i in 1:n_vecs + if !(_Δvecs[i] isa AbstractZero) + Δvecs[i] = _Δvecs[i] + end + end end end - # Compute actual pullback data: #------------------------------ ws = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n),