Skip to content

Commit

Permalink
bring LSMR implementation of #46 up to date
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorVanthilt committed Dec 6, 2024
1 parent 4d2a06f commit 448acfa
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export initialize, initialize!, expand!, shrink!
export ClassicalGramSchmidt, ClassicalGramSchmidt2, ClassicalGramSchmidtIR
export ModifiedGramSchmidt, ModifiedGramSchmidt2, ModifiedGramSchmidtIR
export LanczosIterator, ArnoldiIterator, GKLIterator
export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe
export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe, LSMR
export KrylovDefaults, EigSorter
export RecursiveVec, InnerProductVec

Expand Down Expand Up @@ -235,6 +235,7 @@ include("linsolve/linsolve.jl")
include("linsolve/cg.jl")
include("linsolve/gmres.jl")
include("linsolve/bicgstab.jl")
include("linsolve/lsmr.jl")

# eigsolve and svdsolve
include("eigsolve/eigsolve.jl")
Expand Down
44 changes: 44 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,50 @@ function GMRES(;
return GMRES(orth, maxiter, krylovdim, tol, verbosity)
end

"""
LSMR(; orth = KrylovDefaults.orth,atol = KrylovDefaults.tol,btol = KrylovDefaults.tol,conlim = 1/KrylovDefaults.tol,
maxiter = KrylovDefaults.maxiter,krylovdim = KrylovDefaults.krylovdim,λ = 0.0,verbosity = 0)
Represents the LSMR algorithm, which minimizes ``\\|Ax - b\\|^2 + \\|λx\\|^2`` in the Euclidean norm.
If multiple solutions exists the minimum norm solution is returned.
The method is based on the Golub-Kahan bidiagonalization process. It is
algebraically equivalent to applying MINRES to the normal equations
``(A^*A + λ^2I)x = A^*b``, but has better numerical properties,
especially if ``A`` is ill-conditioned.
- `atol::Number = 1e-6`, `btol::Number = 1e-6`: stopping tolerances. If both are
1.0e-9 (say), the final residual norm should be accurate to about 9 digits.
(The final `x` will usually have fewer correct digits,
depending on `cond(A)` and the size of damp).
- `conlim::Number = 1e8`: stopping tolerance. `lsmr` terminates if an estimate
of `cond(A)` exceeds conlim. For compatible systems Ax = b,
conlim could be as large as 1.0e+12 (say). For least-squares
problems, conlim should be less than 1.0e+8.
Maximum precision can be obtained by setting
- `atol` = `btol` = `conlim` = zero, but the number of iterations
may then be excessive.
"""
struct LSMR{O<:Orthogonalizer,S<:Real} <: KrylovAlgorithm
orth::O
atol::S
btol::S
conlim::S
maxiter::Int
verbosity::Int
λ::S
krylovdim::Int
end
LSMR(; orth = KrylovDefaults.orth,
atol = KrylovDefaults.tol,
btol = KrylovDefaults.tol,
conlim = 1/min(atol,btol),
maxiter = KrylovDefaults.maxiter,
krylovdim = KrylovDefaults.krylovdim,
λ = zero(atol),
verbosity = 0) = LSMR(orth,atol,btol,conlim,maxiter,verbosity,λ,krylovdim)



# TODO
"""
MINRES(; maxiter = KrylovDefaults.maxiter, tol = KrylovDefaults.tol)
Expand Down
211 changes: 211 additions & 0 deletions src/linsolve/lsmr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# reference implementation https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl/blob/master/src/lsmr.jl
function linsolve(operator, b, alg::LSMR)
return linsolve(operator, b, 0 * apply_adjoint(operator, b), alg)
end;
function linsolve(operator, b, x, alg::LSMR)
u = axpby!(1, b, -1, apply_normal(operator, x))
β = norm(u)

# initialize GKL factorization
iter = GKLIterator(operator, u, alg.orth)
fact = initialize(iter; verbosity=alg.verbosity - 2)
numops = 2
sizehint!(fact, alg.krylovdim)

T = eltype(fact)
Tr = real(T)
alg.conlim > 0 ? ctol = convert(Tr, inv(alg.conlim)) : ctol = zero(Tr)
istop = 0

for topit in 1:(alg.maxiter)# the outermost restart loop
# Initialize variables for 1st iteration.
α = fact.αs[end]
ζbar = α * β
αbar = α
ρ = one(Tr)
ρbar = one(Tr)
cbar = one(Tr)
sbar = zero(Tr)

# Initialize variables for estimation of ||r||.
βdd = β
βd = zero(Tr)
ρdold = one(Tr)
τtildeold = zero(Tr)
θtilde = zero(Tr)
ζ = zero(Tr)
d = zero(Tr)

# Initialize variables for estimation of ||A|| and cond(A).
normA, condA, normx = -one(Tr), -one(Tr), -one(Tr)
normA2 = abs2(α)
maxrbar = zero(Tr)
minrbar = 1e100

# Items for use in stopping rules.
normb = β
normr = β
normAr = α * β

hbar = zero(T) * x
h = one(T) * fact.V[end]

while length(fact) < alg.krylovdim
β = normres(fact)
fact = expand!(iter, fact)
numops += 2

v = fact.V[end]
α = fact.αs[end]

# Construct rotation Qhat_{k,2k+1}.
αhat = hypot(αbar, alg.λ)
chat = αbar / αhat
shat = alg.λ / αhat

# Use a plane rotation (Q_i) to turn B_i to R_i.
ρold = ρ
ρ = hypot(αhat, β)
c = αhat / ρ
s = β / ρ
θnew = s * α
αbar = c * α

# Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
ρbarold = ρbar
ζold = ζ
θbar = sbar * ρ
ρtemp = cbar * ρ
ρbar = hypot(cbar * ρ, θnew)
cbar = cbar * ρ / ρbar
sbar = θnew / ρbar
ζ = cbar * ζbar
ζbar = -sbar * ζbar

# Update h, h_hat, x.
hbar = axpby!(1, h, -θbar * ρ / (ρold * ρbarold), hbar)
h = axpby!(1, v, -θnew / ρ, h)
x = axpy!/* ρbar), hbar, x)

##############################################################################
##
## Estimate of ||r||
##
##############################################################################

# Apply rotation Qhat_{k,2k+1}.
βacute = chat * βdd
βcheck = -shat * βdd

# Apply rotation Q_{k,k+1}.
βhat = c * βacute
βdd = -s * βacute

# Apply rotation Qtilde_{k-1}.
θtildeold = θtilde
ρtildeold = hypot(ρdold, θbar)
ctildeold = ρdold / ρtildeold
stildeold = θbar / ρtildeold
θtilde = stildeold * ρbar
ρdold = ctildeold * ρbar
βd = -stildeold * βd + ctildeold * βhat

τtildeold = (ζold - θtildeold * τtildeold) / ρtildeold
τd =- θtilde * τtildeold) / ρdold
d += abs2(βcheck)
normr = sqrt(d + abs2(βd - τd) + abs2(βdd))

# Estimate ||A||.
normA2 += abs2(β)
normA = sqrt(normA2)
normA2 += abs2(α)

# Estimate cond(A).
maxrbar = max(maxrbar, ρbarold)
if length(fact) > 1
minrbar = min(minrbar, ρbarold)
end
condA = max(maxrbar, ρtemp) / min(minrbar, ρtemp)

##############################################################################
##
## Test for convergence
##
##############################################################################

# Compute norms for convergence testing.
normAr = abs(ζbar)
normx = norm(x)

# Now use these norms to estimate certain other quantities,
# some of which will be small near a solution.
test1 = normr / normb
test2 = normAr / (normA * normr)
test3 = inv(condA)

t1 = test1 / (one(Tr) + normA * normx / normb)
rtol = alg.btol + alg.atol * normA * normx / normb
# The following tests guard against extremely small values of
# atol, btol or ctol. (The user may have set any or all of
# the parameters atol, btol, conlim to 0.)
# The effect is equivalent to the normAl tests using
# atol = eps, btol = eps, conlim = 1/eps.

if alg.verbosity > 2
msg = "LSMR linsolve in iter $topit; step $(length(fact)-1): "
msg *= "normres = "
msg *= @sprintf("%.12e", normr)
@info msg
end

if 1 + test3 <= 1
istop = 6
break
end
if 1 + test2 <= 1
istop = 5
break
end
if 1 + t1 <= 1
istop = 4
break
end
# Allow for tolerances set by the user.
if test3 <= ctol
istop = 3
break
end
if test2 <= alg.atol
istop = 2
break
end
if test1 <= rtol
istop = 1
break
end
end

u = axpby!(1, b, -1, apply_normal(operator, x))

istop != 0 && break

#restart
β = norm(u)
iter = GKLIterator(operator, u, alg.orth)
fact = initialize!(iter, fact)
end

isconv = istop (0, 3, 6)
if alg.verbosity > 0 && !isconv
@warn """LSMR linsolve finished without converging after $(alg.maxiter) iterations:
* norm of residual = $(norm(u))
* number of operations = $numops"""
elseif alg.verbosity > 0
if alg.verbosity > 0
@info """LSMR linsolve converged due to istop $(istop):
* norm of residual = $(norm(u))
* number of operations = $numops"""
end
end
return (x, ConvergenceInfo(Int(isconv), u, norm(u), alg.maxiter, numops))
end
28 changes: 28 additions & 0 deletions test/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ end
end
end

# Test LSMR complete
@testset "full lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (n,n))
v = rand(T,n);
w = rand(T,n);
alg = LSMR(orth = orth, krylovdim = 2*n, maxiter = 1, atol = 10*n*eps(real(T)),btol = 10*n*eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
@test info.converged > 0
@test vA*unwrapvec(S)+unwrapvec(info.residual)
end
end
end
@testset "iterative lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (N,N))
v = rand(T,N);
w = rand(T,N);
alg = LSMR(orth = orth, krylovdim = N, maxiter = 50, atol = 10*N*eps(real(T)),btol = 10*N*eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
@test info.converged > 0
@test vA*unwrapvec(S)+unwrapvec(info.residual)
end
end
end

# Test GMRES complete
@testset "GMRES full factorization ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
Expand Down

0 comments on commit 448acfa

Please sign in to comment.