Skip to content

Commit

Permalink
Merge pull request #290 from SciML/normalcholesky
Browse files Browse the repository at this point in the history
Add NormalCholeskyFactorization and WellConditioned defaults
  • Loading branch information
ChrisRackauckas authored Mar 26, 2023
2 parents 8277c12 + f322e55 commit af79d26
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "1.40.0"
version = "1.41.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ as otherwise that will need to be determined at runtime.
## What if I found a faster algorithm than the one LinearSolve.jl chose?

What assumptions are made as part of your method? If your method only works on well-conditioned operators, then
make sure you set the `WellConditioned` assumption in the `assumptions`. See the
make sure you set the `WellConditioned` assumption in the `assumptions`. See the
[OperatorAssumptions page for more details](@ref assumptions). If using the right assumptions still
does not bring performance to the expected level, please open an issue and we will improve the default algorithm.

Expand Down
2 changes: 1 addition & 1 deletion docs/src/basics/OperatorAssumptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ IllConditioned
VeryIllConditioned
SuperIllConditioned
WellConditioned
```
```
1 change: 1 addition & 0 deletions docs/src/basics/common_solver_opts.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The following are the options these algorithms take, along with their defaults.
- `verbose`: Whether to print extra information. Defaults to `false`.
- `assumptions`: Sets the assumptions of the operator in order to effect the default
choice algorithm. See the [Operator Assumptions page for more details](@ref assumptions).

## Iterative Solver Controls

Error controls are not used by all algorithms. Specifically, direct solves always
Expand Down
6 changes: 4 additions & 2 deletions ext/LinearSolveHYPRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using IterativeSolvers: Identity
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare, set_cacheval
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning, set_cacheval
using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
Expand Down Expand Up @@ -82,7 +83,8 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,

cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions)
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions),
__conditioning(assumptions)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
Expand Down
3 changes: 3 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ end

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization

Expand All @@ -119,4 +120,6 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,

export HYPREAlgorithm

export OperatorAssumptions, OperatorCondition

end
27 changes: 16 additions & 11 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@ end
Sets the operator `A` assumptions used as part of the default algorithm
"""
struct OperatorAssumptions{issq,condition} end
function OperatorAssumptions(issquare = nothing; condition::OperatorCondition.T = OperatorCondition.IllConditioned)
struct OperatorAssumptions{issq, condition} end
function OperatorAssumptions(issquare = nothing;
condition::OperatorCondition.T = OperatorCondition.IllConditioned)
issq = something(_unwrap_val(issquare), Nothing)
condition = _unwrap_val(condition)
OperatorAssumptions{issq,condition}()
OperatorAssumptions{issq, condition}()
end
__issquare(::OperatorAssumptions{issq,condition}) where {issq,condition} = issq
__conditioning(::OperatorAssumptions{issq,condition}) where {issq,condition} = condition
__issquare(::OperatorAssumptions{issq, condition}) where {issq, condition} = issq
__conditioning(::OperatorAssumptions{issq, condition}) where {issq, condition} = condition


struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, condition}
A::TA
b::Tb
u::Tu
Expand All @@ -77,7 +77,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
reltol::Ttol
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issq}
assumptions::OperatorAssumptions{issq, condition}
end

"""
Expand Down Expand Up @@ -143,9 +143,13 @@ default_tol(::Type{<:Rational}) = 0
default_tol(::Type{<:Integer}) = 0
default_tol(::Type{Any}) = 0

function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm, Nothing},
default_alias_A(::Any, ::Any, ::Any) = false
default_alias_b(::Any, ::Any, ::Any) = false

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = false, alias_b = false,
alias_A = default_alias_A(alg, prob.A, prob.b),
alias_b = default_alias_b(alg, prob.A, prob.b),
abstol = default_tol(eltype(prob.A)),
reltol = default_tol(eltype(prob.A)),
maxiters::Int = length(prob.b),
Expand Down Expand Up @@ -187,7 +191,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
typeof(Pl),
typeof(Pr),
typeof(reltol),
__issquare(assumptions)
__issquare(assumptions),
__conditioning(assumptions)
}(A,
b,
u0,
Expand Down
69 changes: 59 additions & 10 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{true,OperatorCondition.IllConditioned})
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b,
assump::OperatorAssumptions{true, OperatorCondition.IllConditioned})
QRFactorization()
end

Expand All @@ -86,7 +87,8 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssump
end
end

function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{true,OperatorCondition.IllConditioned})
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray,
assump::OperatorAssumptions{true, OperatorCondition.IllConditioned})
QRFactorization()
end

Expand Down Expand Up @@ -130,7 +132,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.Abstract
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
::OperatorAssumptions{true,OperatorCondition.IllConditioned})
::OperatorAssumptions{true, OperatorCondition.IllConditioned})
QRFactorization()
end

Expand All @@ -155,17 +157,50 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
# whether MKL or OpenBLAS is being used
if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
ArrayInterface.can_setindex(b) && __conditioning(assump) === OperatorCondition.IllConditioned
ArrayInterface.can_setindex(b) &&
(__conditioning(assump) === OperatorCondition.IllConditioned ||
__conditioning(assump) === OperatorCondition.WellConditioned)
if length(b) <= 10
alg = GenericLUFactorization()
pivot = @static if VERSION < v"1.7beta"
if __conditioning(assump) === OperatorCondition.IllConditioned
Val(true)
else
Val(false)
end
else
if __conditioning(assump) === OperatorCondition.IllConditioned
RowMaximum()
else
RowNonZero()
end
end
alg = GenericLUFactorization(pivot)
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
alg = RFLUFactorization()
pivot = if __conditioning(assump) === OperatorCondition.IllConditioned
Val(true)
else
Val(false)
end
alg = RFLUFactorization(; pivot = pivot)
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
else
alg = LUFactorization()
pivot = @static if VERSION < v"1.7beta"
if __conditioning(assump) === OperatorCondition.IllConditioned
Val(true)
else
Val(false)
end
else
if __conditioning(assump) === OperatorCondition.IllConditioned
RowMaximum()
else
RowNonZero()
end
end
alg = LUFactorization(pivot)
end
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
alg = QRFactorization()
Expand All @@ -187,20 +222,34 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
alg
end

function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.IllConditioned})
# Non-square operator assumed well-conditioned: take the normal-equations
# Cholesky path, which is the cheapest option when conditioning is no concern.
function defaultalg(A, b,
        ::OperatorAssumptions{false, OperatorCondition.WellConditioned})
    return NormalCholeskyFactorization()
end

function defaultalg(A, b, ::OperatorAssumptions{false, OperatorCondition.IllConditioned})
QRFactorization()
end

function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.VeryIllConditioned})
function defaultalg(A, b,
::OperatorAssumptions{false, OperatorCondition.VeryIllConditioned})
QRFactorization()
end

function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.SuperIllConditioned})
function defaultalg(A, b,
::OperatorAssumptions{false, OperatorCondition.SuperIllConditioned})
SVDFactorization(false, LinearAlgebra.QRIteration())
end

## Catch high level interface

function SciMLBase.init(prob::LinearProblem, alg::Nothing, args...;
        assumptions = OperatorAssumptions(Val(issquare(prob.A))),
        kwargs...)
    # `alg === nothing` means "pick for me": resolve the default algorithm from
    # the operator, right-hand side, and assumptions, then forward everything
    # to the algorithm-specific `init`.
    chosen = defaultalg(prob.A, prob.b, assumptions)
    return SciMLBase.init(prob, chosen, args...; assumptions, kwargs...)
end

function SciMLBase.solve(cache::LinearCache, alg::Nothing,
args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
kwargs...)
Expand Down
78 changes: 78 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,84 @@ function SciMLBase.solve(cache::LinearCache, alg::RFLUFactorization{P, T};
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## NormalCholeskyFactorization

"""
    NormalCholeskyFactorization(; pivot = nothing)

Solve a linear problem through the normal equations: factor `A' * A` with a
Cholesky factorization and apply it to `A' * b`. For dense matrices `pivot` is
forwarded to `cholesky`; when left as `nothing`, a version-appropriate default
pivoting strategy is selected.
"""
struct NormalCholeskyFactorization{P} <: AbstractFactorization
    pivot::P
end

function NormalCholeskyFactorization(; pivot = nothing)
    if pivot === nothing
        # The dense pivoting API changed in Julia 1.7: `Val(true)` before,
        # `RowMaximum()` afterwards.
        pivot = @static VERSION < v"1.7beta" ? Val(true) : RowMaximum()
    end
    return NormalCholeskyFactorization(pivot)
end

# The normal-equations solve only reads `A` and `b` (to form `A'A` and `A'b`),
# so the cache may safely alias the user's arrays instead of copying them.
default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
default_alias_A(::NormalCholeskyFactorization, ::Any, ::Any) = true

function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # Pre-build a Cholesky-factorization object so the cache slot is concretely
    # typed before `solve` performs the first real factorization.
    # NOTE(review): assumes `ArrayInterface.cholesky_instance` returns a
    # placeholder consistent with `cholesky(..., alg.pivot)` — confirm upstream.
    mat = convert(AbstractMatrix, A)
    return ArrayInterface.cholesky_instance(mat, alg.pivot)
end

function SciMLBase.solve(cache::LinearCache, alg::NormalCholeskyFactorization;
        kwargs...)
    # Solve via the normal equations: factor `A'A` with Cholesky and apply the
    # factorization to `A'b`. Valid only when `A'A` is positive definite, i.e.
    # `A` has full column rank.
    A = convert(AbstractMatrix, cache.A)
    sparseA = A isa SparseMatrixCSC
    if cache.isfresh
        # `Hermitian` rather than `Symmetric`: the wrappers are equivalent for
        # real eltypes, but for complex eltypes `A'A` is Hermitian, not
        # symmetric, and `Symmetric` would assert the wrong structure.
        if sparseA
            # The sparse (CHOLMOD) method does not take the dense pivot option.
            fact = cholesky(Hermitian(A' * A))
        else
            fact = cholesky(Hermitian(A' * A), alg.pivot)
        end
        cache = set_cacheval(cache, fact)
    end
    if sparseA
        # NOTE(review): sparse path falls back to `\` — presumably because
        # `ldiv!` is not defined for CHOLMOD factors; confirm.
        cache.u .= cache.cacheval \ (A' * cache.b)
        y = cache.u
    else
        y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
    end
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## NormalBunchKaufmanFactorization

"""
    NormalBunchKaufmanFactorization(; rook = false)

Solve a linear problem through the normal equations: factor `A' * A` with a
Bunch-Kaufman factorization and apply it to `A' * b`. Set `rook = true` to use
rook pivoting in `bunchkaufman`.
"""
Base.@kwdef struct NormalBunchKaufmanFactorization <: AbstractFactorization
    rook::Bool = false
end

# The normal-equations solve only reads `A` and `b` (to form `A'A` and `A'b`),
# so the cache may safely alias the user's arrays instead of copying them.
default_alias_b(::NormalBunchKaufmanFactorization, ::Any, ::Any) = true
default_alias_A(::NormalBunchKaufmanFactorization, ::Any, ::Any) = true

function init_cacheval(alg::NormalBunchKaufmanFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # Pre-build a Bunch-Kaufman factorization object so the cache slot is
    # concretely typed before `solve` performs the first real factorization.
    mat = convert(AbstractMatrix, A)
    return ArrayInterface.bunchkaufman_instance(mat)
end

function SciMLBase.solve(cache::LinearCache, alg::NormalBunchKaufmanFactorization;
        kwargs...)
    # Solve via the normal equations with a Bunch-Kaufman factorization of
    # `A'A`, which requires only that `A'A` be Hermitian (not necessarily
    # positive definite).
    A = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        # `Hermitian` rather than `Symmetric`: the wrappers agree for real
        # eltypes, but for complex eltypes `A'A` is Hermitian, not symmetric,
        # and `Symmetric` would assert the wrong structure.
        fact = bunchkaufman(Hermitian(A' * A), alg.rook)
        cache = set_cacheval(cache, fact)
    end
    y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## DiagonalFactorization

struct DiagonalFactorization <: AbstractFactorization end
Expand Down
8 changes: 4 additions & 4 deletions src/factorization_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ function _ldiv!(x::Vector,
end

function _ldiv!(x::AbstractVector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SuiteSparse.SPQR.QRSparse,
SuiteSparse.CHOLMOD.Factor}, b::AbstractVector)
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SuiteSparse.SPQR.QRSparse,
SuiteSparse.CHOLMOD.Factor}, b::AbstractVector)
x .= A \ b
end
end
9 changes: 9 additions & 0 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
args, kwargs)
end

# Default to aliasing `A` and `b` for the Krylov.jl wrapper: no copies are made
# when building the cache. NOTE(review): assumes the underlying Krylov solvers
# never overwrite the operator or right-hand side — confirm.
default_alias_b(::KrylovJL, ::Any, ::Any) = true
default_alias_A(::KrylovJL, ::Any, ::Any) = true

function KrylovJL_CG(args...; kwargs...)
KrylovJL(args...; KrylovAlg = Krylov.cg!, kwargs...)
end
Expand Down Expand Up @@ -205,6 +208,9 @@ function IterativeSolversJL(args...;
args, kwargs)
end

# Default to aliasing `A` and `b` for the IterativeSolvers.jl wrapper: no
# copies are made when building the cache. NOTE(review): assumes the iterators
# never overwrite the operator or right-hand side — confirm.
default_alias_b(::IterativeSolversJL, ::Any, ::Any) = true
default_alias_A(::IterativeSolversJL, ::Any, ::Any) = true

function IterativeSolversJL_CG(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.cg_iterator!,
Expand Down Expand Up @@ -312,6 +318,9 @@ function KrylovKitJL(args...;
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
end

# Default to aliasing `A` and `b` for the KrylovKit.jl wrapper: no copies are
# made when building the cache. NOTE(review): assumes the underlying solvers
# never overwrite the operator or right-hand side — confirm.
default_alias_b(::KrylovKitJL, ::Any, ::Any) = true
default_alias_A(::KrylovKitJL, ::Any, ::Any) = true

function KrylovKitJL_CG(args...; kwargs...)
KrylovKitJL(args...; KrylovAlg = KrylovKit.CG, kwargs..., isposdef = true)
end
Expand Down
Loading

0 comments on commit af79d26

Please sign in to comment.