From 448acfa19cfe27967915b1f726e1ef6b6ed51542 Mon Sep 17 00:00:00 2001
From: victor
Date: Fri, 6 Dec 2024 11:55:30 +0100
Subject: [PATCH] bring LSMR implementation of jutho/krylovkit.jl#46 up to date

---
Review notes (this section is ignored by `git am`):
- The commit id above and the blob ids on the `index` lines predate the review
  changes below; regenerate the patch from a commit if they need to be exact.

 src/KrylovKit.jl     |   3 +-
 src/algorithms.jl    |  56 +++++++++++
 src/linsolve/lsmr.jl | 224 +++++++++++++++++++++++++++++++++++++++++++
 test/linsolve.jl     |  28 ++++++
 4 files changed, 310 insertions(+), 1 deletion(-)
 create mode 100644 src/linsolve/lsmr.jl

diff --git a/src/KrylovKit.jl b/src/KrylovKit.jl
index 4f9a702..979c1fc 100644
--- a/src/KrylovKit.jl
+++ b/src/KrylovKit.jl
@@ -37,7 +37,7 @@ export initialize, initialize!, expand!, shrink!
 export ClassicalGramSchmidt, ClassicalGramSchmidt2, ClassicalGramSchmidtIR
 export ModifiedGramSchmidt, ModifiedGramSchmidt2, ModifiedGramSchmidtIR
 export LanczosIterator, ArnoldiIterator, GKLIterator
-export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe
+export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe, LSMR
 export KrylovDefaults, EigSorter
 export RecursiveVec, InnerProductVec
 
@@ -235,6 +235,7 @@ include("linsolve/linsolve.jl")
 include("linsolve/cg.jl")
 include("linsolve/gmres.jl")
 include("linsolve/bicgstab.jl")
+include("linsolve/lsmr.jl")
 
 # eigsolve and svdsolve
 include("eigsolve/eigsolve.jl")
diff --git a/src/algorithms.jl b/src/algorithms.jl
index b771c96..c09a3b5 100644
--- a/src/algorithms.jl
+++ b/src/algorithms.jl
@@ -281,6 +281,62 @@ function GMRES(;
     return GMRES(orth, maxiter, krylovdim, tol, verbosity)
 end
 
+"""
+    LSMR(; orth = KrylovDefaults.orth, atol = KrylovDefaults.tol,
+         btol = KrylovDefaults.tol, conlim = 1 / min(atol, btol),
+         maxiter = KrylovDefaults.maxiter, krylovdim = KrylovDefaults.krylovdim,
+         λ = zero(atol), verbosity = 0)
+
+Represents the LSMR algorithm, which minimizes ``\\|Ax - b\\|^2 + \\|λx\\|^2`` in the
+Euclidean norm. If multiple solutions exist, the minimum norm solution is returned.
+The method is based on the Golub-Kahan bidiagonalization process. It is
+algebraically equivalent to applying MINRES to the normal equations
+``(A^*A + λ^2I)x = A^*b``, but has better numerical properties,
+especially if ``A`` is ill-conditioned.
+
+# Keywords
+
+- `orth`: orthogonalizer handed to the Golub-Kahan bidiagonalization.
+- `atol`, `btol`: stopping tolerances. If both are 1.0e-9 (say), the final residual
+  norm should be accurate to about 9 digits. (The final `x` will usually have fewer
+  correct digits, depending on `cond(A)` and the size of the damping `λ`.)
+- `conlim`: stopping tolerance. The iteration terminates if an estimate of `cond(A)`
+  exceeds `conlim`. For compatible systems ``Ax = b``, `conlim` could be as large as
+  1.0e+12 (say). For least-squares problems, `conlim` should be less than 1.0e+8.
+- `maxiter`: maximum number of bidiagonalization cycles (restarts).
+- `krylovdim`: size up to which the bidiagonalization is expanded within one cycle.
+- `λ`: damping parameter in the objective above.
+- `verbosity`: a value `> 0` reports the outcome, a value `> 2` logs every step.
+
+Maximum precision can be obtained by setting `atol` = `btol` = `conlim` = zero, but the
+number of iterations may then be excessive.
+
+The keywords `atol`, `btol`, `conlim` and `λ` are promoted to a common type.
+"""
+struct LSMR{O<:Orthogonalizer,S<:Real} <: KrylovAlgorithm
+    orth::O
+    atol::S
+    btol::S
+    conlim::S
+    maxiter::Int
+    verbosity::Int
+    λ::S
+    krylovdim::Int
+end
+function LSMR(; orth=KrylovDefaults.orth,
+              atol=KrylovDefaults.tol,
+              btol=KrylovDefaults.tol,
+              conlim=1 / min(atol, btol),
+              maxiter=KrylovDefaults.maxiter,
+              krylovdim=KrylovDefaults.krylovdim,
+              λ=zero(atol),
+              verbosity=0)
+    # The struct stores all four in one type parameter; promote so that mixed inputs
+    # such as `atol = 1e-8, λ = 1` are accepted.
+    atol, btol, conlim, λ = promote(atol, btol, conlim, λ)
+    return LSMR(orth, atol, btol, conlim, maxiter, verbosity, λ, krylovdim)
+end
+
 # TODO
 """
     MINRES(; maxiter = KrylovDefaults.maxiter, tol = KrylovDefaults.tol)
diff --git a/src/linsolve/lsmr.jl b/src/linsolve/lsmr.jl
new file mode 100644
index 0000000..0da52c3
--- /dev/null
+++ b/src/linsolve/lsmr.jl
@@ -0,0 +1,224 @@
+# reference implementation https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl/blob/master/src/lsmr.jl
+
+# Least-squares solve with LSMR, starting from the zero vector.
+function linsolve(operator, b, alg::LSMR)
+    return linsolve(operator, b, 0 * apply_adjoint(operator, b), alg)
+end
+
+# Least-squares solve with LSMR, starting from the initial guess `x`.
+# Returns `(x, info)`; `info.residual` is `b - A * x` evaluated at the returned `x`.
+function linsolve(operator, b, x, alg::LSMR)
+    # Work on a copy: the iterate is updated in place below and the caller's initial
+    # guess should not change under a function without a `!`.
+    x = 1 * x
+    u = axpby!(1, b, -1, apply_normal(operator, x))
+    β = norm(u)
+    numops = 1
+    numiter = 0
+
+    # A vanishing residual means `x` already solves the system. Return before setting
+    # up the bidiagonalization, which presumably cannot start from a zero vector.
+    iszero(β) && return (x, ConvergenceInfo(1, u, β, numiter, numops))
+
+    # initialize GKL factorization
+    iter = GKLIterator(operator, u, alg.orth)
+    fact = initialize(iter; verbosity=alg.verbosity - 2)
+    numops += 1
+    sizehint!(fact, alg.krylovdim)
+
+    T = eltype(fact)
+    Tr = real(T)
+    # Keep the scalar recurrences in the working precision `Tr`.
+    λ = convert(Tr, alg.λ)
+    ctol = alg.conlim > 0 ? convert(Tr, inv(alg.conlim)) : zero(Tr)
+    istop = 0
+
+    for topit in 1:(alg.maxiter) # the outermost restart loop
+        numiter = topit
+
+        # Initialize variables for 1st iteration.
+        α = fact.αs[end]
+        ζbar = α * β
+        αbar = α
+        ρ = one(Tr)
+        ρbar = one(Tr)
+        cbar = one(Tr)
+        sbar = zero(Tr)
+
+        # Initialize variables for estimation of ||r||.
+        βdd = β
+        βd = zero(Tr)
+        ρdold = one(Tr)
+        τtildeold = zero(Tr)
+        θtilde = zero(Tr)
+        ζ = zero(Tr)
+        d = zero(Tr)
+
+        # Initialize variables for estimation of ||A|| and cond(A).
+        normA, condA, normx = -one(Tr), -one(Tr), -one(Tr)
+        normA2 = abs2(α)
+        maxrbar = zero(Tr)
+        minrbar = floatmax(Tr)
+
+        # Items for use in stopping rules.
+        normb = β
+        normr = β
+        normAr = α * β
+
+        hbar = zero(T) * x
+        h = one(T) * fact.V[end]
+
+        while length(fact) < alg.krylovdim
+            β = normres(fact)
+            fact = expand!(iter, fact)
+            numops += 2
+
+            v = fact.V[end]
+            α = fact.αs[end]
+
+            # Construct rotation Qhat_{k,2k+1}.
+            αhat = hypot(αbar, λ)
+            chat = αbar / αhat
+            shat = λ / αhat
+
+            # Use a plane rotation (Q_i) to turn B_i to R_i.
+            ρold = ρ
+            ρ = hypot(αhat, β)
+            c = αhat / ρ
+            s = β / ρ
+            θnew = s * α
+            αbar = c * α
+
+            # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
+            ρbarold = ρbar
+            ζold = ζ
+            θbar = sbar * ρ
+            ρtemp = cbar * ρ
+            ρbar = hypot(cbar * ρ, θnew)
+            cbar = cbar * ρ / ρbar
+            sbar = θnew / ρbar
+            ζ = cbar * ζbar
+            ζbar = -sbar * ζbar
+
+            # Update h, h_hat, x.
+            hbar = axpby!(1, h, -θbar * ρ / (ρold * ρbarold), hbar)
+            h = axpby!(1, v, -θnew / ρ, h)
+            x = axpy!(ζ / (ρ * ρbar), hbar, x)
+
+            ##############################################################################
+            ##
+            ## Estimate of ||r||
+            ##
+            ##############################################################################
+
+            # Apply rotation Qhat_{k,2k+1}.
+            βacute = chat * βdd
+            βcheck = -shat * βdd
+
+            # Apply rotation Q_{k,k+1}.
+            βhat = c * βacute
+            βdd = -s * βacute
+
+            # Apply rotation Qtilde_{k-1}.
+            θtildeold = θtilde
+            ρtildeold = hypot(ρdold, θbar)
+            ctildeold = ρdold / ρtildeold
+            stildeold = θbar / ρtildeold
+            θtilde = stildeold * ρbar
+            ρdold = ctildeold * ρbar
+            βd = -stildeold * βd + ctildeold * βhat
+
+            τtildeold = (ζold - θtildeold * τtildeold) / ρtildeold
+            τd = (ζ - θtilde * τtildeold) / ρdold
+            d += abs2(βcheck)
+            normr = sqrt(d + abs2(βd - τd) + abs2(βdd))
+
+            # Estimate ||A||.
+            normA2 += abs2(β)
+            normA = sqrt(normA2)
+            normA2 += abs2(α)
+
+            # Estimate cond(A). On the first step of a cycle `ρbarold` still holds its
+            # placeholder value `one(Tr)` and must not enter the running minimum. That
+            # step has `length(fact) == 2`, assuming a freshly initialized factorization
+            # has length 1 (as the step counter in the log message below suggests).
+            maxrbar = max(maxrbar, ρbarold)
+            if length(fact) > 2
+                minrbar = min(minrbar, ρbarold)
+            end
+            condA = max(maxrbar, ρtemp) / min(minrbar, ρtemp)
+
+            ##############################################################################
+            ##
+            ## Test for convergence
+            ##
+            ##############################################################################
+
+            # Compute norms for convergence testing.
+            normAr = abs(ζbar)
+            normx = norm(x)
+
+            # Now use these norms to estimate certain other quantities,
+            # some of which will be small near a solution.
+            test1 = normr / normb
+            test2 = normAr / (normA * normr)
+            test3 = inv(condA)
+
+            t1 = test1 / (one(Tr) + normA * normx / normb)
+            rtol = alg.btol + alg.atol * normA * normx / normb
+
+            if alg.verbosity > 2
+                msg = "LSMR linsolve in iter $topit; step $(length(fact)-1): "
+                msg *= "normres = "
+                msg *= @sprintf("%.12e", normr)
+                @info msg
+            end
+
+            # The first three tests guard against extremely small values of
+            # atol, btol or ctol. (The user may have set any or all of
+            # the parameters atol, btol, conlim to 0.)
+            # The effect is equivalent to the normal tests using
+            # atol = eps, btol = eps, conlim = 1/eps.
+            if 1 + test3 <= 1
+                istop = 6
+            elseif 1 + test2 <= 1
+                istop = 5
+            elseif 1 + t1 <= 1
+                istop = 4
+            elseif test3 <= ctol # from here on: tolerances set by the user
+                istop = 3
+            elseif test2 <= alg.atol
+                istop = 2
+            elseif test1 <= rtol
+                istop = 1
+            end
+            istop != 0 && break
+        end
+
+        # Recompute the residual explicitly from the current iterate.
+        u = axpby!(1, b, -1, apply_normal(operator, x))
+        numops += 1
+
+        istop != 0 && break
+
+        # restart
+        β = norm(u)
+        iter = GKLIterator(operator, u, alg.orth)
+        fact = initialize!(iter, fact)
+        numops += 1
+    end
+
+    isconv = istop ∉ (0, 3, 6)
+    if alg.verbosity > 0
+        if isconv
+            @info """LSMR linsolve converged due to istop $(istop):
+             * norm of residual = $(norm(u))
+             * number of operations = $numops"""
+        else
+            @warn """LSMR linsolve finished without converging after $numiter iterations:
+             * norm of residual = $(norm(u))
+             * number of operations = $numops"""
+        end
+    end
+    return (x, ConvergenceInfo(Int(isconv), u, norm(u), numiter, numops))
+end
diff --git a/test/linsolve.jl b/test/linsolve.jl
index d2a4211..f26935f 100644
--- a/test/linsolve.jl
+++ b/test/linsolve.jl
@@ -52,6 +52,34 @@ end
     end
 end
 
+# Test LSMR complete
+@testset "full lsmr" begin
+    @testset for T in (Float32, Float64, ComplexF32, ComplexF64)
+        @testset for orth in (cgs2, mgs2, cgsr, mgsr)
+            A = rand(T, (n,n))
+            v = rand(T,n);
+            w = rand(T,n);
+            alg = LSMR(orth = orth, krylovdim = 2*n, maxiter = 1, atol = 10*n*eps(real(T)),btol = 10*n*eps(real(T)))
+            S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
+            @test info.converged > 0
+            @test v≈A*unwrapvec(S)+unwrapvec(info.residual)
+        end
+    end
+end
+@testset "iterative lsmr" begin
+    @testset for T in (Float32, Float64, ComplexF32, ComplexF64)
+        @testset for orth in (cgs2, mgs2, cgsr, mgsr)
+            A = rand(T, (N,N))
+            v = rand(T,N);
+            w = rand(T,N);
+            alg = LSMR(orth = orth, krylovdim = N, maxiter = 50, atol = 10*N*eps(real(T)),btol = 10*N*eps(real(T)))
+            S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
+            @test info.converged > 0
+            @test v≈A*unwrapvec(S)+unwrapvec(info.residual)
+        end
+    end
+end
+
 # Test GMRES complete
 @testset "GMRES full factorization ($mode)" for mode in (:vector, :inplace, :outplace)
     scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :