Add an assumptions mechanism for type-stable default help, and nonsquare #187

Merged · 13 commits · Aug 28, 2022
14 changes: 7 additions & 7 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ makedocs(sitename = "LinearSolve.jl",
authors = "Chris Rackauckas",
modules = [LinearSolve, LinearSolve.SciMLBase],
clean = true, doctest = false,
strict=[
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
strict = [
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(analytics = "UA-90474609-3",
assets = ["assets/favicon.ico"],
Expand Down
22 changes: 18 additions & 4 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol}
struct OperatorAssumptions{issquare} end
function OperatorAssumptions(issquare = nothing)
OperatorAssumptions{_unwrap_val(issquare)}()
end

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B::Nothing) = Nothing
_unwrap_val(B::Bool) = B

struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
A::TA
b::Tb
u::Tu
Expand All @@ -12,6 +21,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol}
reltol::Ttol
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issquare}
end

"""
Expand Down Expand Up @@ -86,6 +96,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
verbose = false,
Pl = Identity(),
Pr = Identity(),
assumptions = OperatorAssumptions(),
kwargs...)
@unpack A, b, u0, p = prob

Expand All @@ -96,7 +107,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
fill!(u0, false)
end

cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose)
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
Tc = typeof(cacheval)

Expand All @@ -112,7 +124,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
Tc,
typeof(Pl),
typeof(Pr),
typeof(reltol)
typeof(reltol),
typeof(assumptions)
}(A,
b,
u0,
Expand All @@ -125,7 +138,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
abstol,
reltol,
maxiters,
verbose)
verbose,
assumptions)
return cache
end

Expand Down
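As a usage sketch (not part of the diff, and assuming the `init`/`solve` entry points shown above forward keyword arguments as usual), the new keyword is intended to be passed like this:

    using LinearSolve

    A = rand(4, 4)
    b = rand(4)
    prob = LinearProblem(A, b)

    # OperatorAssumptions() is the OperatorAssumptions{nothing} case:
    # squareness is left undeclared and is checked at runtime when the
    # default algorithm is selected.
    cache = init(prob, nothing; assumptions = OperatorAssumptions())

    # OperatorAssumptions(Val(true)) encodes "A is square" in the type
    # domain, so the LinearCache type and the default algorithm choice
    # are fully inferable.
    cache = init(prob, nothing; assumptions = OperatorAssumptions(Val(true)))
    sol = solve(cache)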
255 changes: 91 additions & 164 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,209 +1,136 @@
 ## Default algorithm
 
-# Allows A === nothing as a stand-in for dense matrix
-function defaultalg(A, b)
-    if A isa DiffEqArrayOperator
-        A = A.A
-    end
-
-    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
-    # it makes sense according to the benchmarks, which is dependent on
-    # whether MKL or OpenBLAS is being used
-    if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
-        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
-           ArrayInterfaceCore.can_setindex(b)
-            if length(b) <= 10
-                alg = GenericLUFactorization()
-            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
-                   eltype(A) <: Union{Float32, Float64}
-                alg = RFLUFactorization()
-            #elseif A === nothing || A isa Matrix
-            #    alg = FastLUFactorization()
-            else
-                alg = LUFactorization()
-            end
-        else
-            alg = LUFactorization()
-        end
+function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
+    defaultalg(A.A, b, assumptions)
+end
 
-    # These few cases ensure the choice is optimal without the
-    # dynamic dispatching of factorize
-    elseif A isa Tridiagonal
-        alg = GenericFactorization(; fact_alg = lu!)
-    elseif A isa SymTridiagonal
-        alg = GenericFactorization(; fact_alg = ldlt!)
-    elseif A isa SparseMatrixCSC
-        if length(b) <= 10_000
-            alg = KLUFactorization()
-        else
-            alg = UMFPACKFactorization()
-        end
+# Ambiguity handling
+function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
+    defaultalg(A.A, b, assumptions)
+end
 
-    # This catches the cases where a factorization overload could exist
-    # For example, BlockBandedMatrix
-    elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
-        alg = GenericFactorization()
+function defaultalg(A, b, ::OperatorAssumptions{nothing})
+    issquare = size(A, 1) == size(A, 2)
+    defaultalg(A, b, OperatorAssumptions(Val(issquare)))
+end

Review thread on the OperatorAssumptions{nothing} method above:

Member: This seems type unstable. Is that intentional, or does Julia always union split here?

Member (author): We need to have some point where a type choice is made, because then it sets the algorithm type. Things like the ODE solver would just set the assumptions at the type level.
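To make the point in the thread concrete, here is a minimal standalone sketch (hypothetical names, not LinearSolve internals) of the pattern under discussion:

    struct Assumptions{issq} end

    choose(A, ::Assumptions{true}) = "LU"    # square path
    choose(A, ::Assumptions{false}) = "QR"   # nonsquare path

    # The {nothing} fallback decides at runtime: Assumptions{issq}() can be
    # one of two types, so the call below is type-unstable. Because only two
    # concrete types are possible, Julia will typically union-split the
    # downstream dispatch, which keeps the runtime cost small.
    function choose(A, ::Assumptions{nothing})
        issq = size(A, 1) == size(A, 2)
        choose(A, Assumptions{issq}())
    end

    choose(rand(3, 3), Assumptions{nothing}())  # runtime check, returns "LU"
    choose(rand(3, 2), Assumptions{false}())    # static choice, returns "QR"

A caller that knows the shape statically, such as an ODE solver building a square W-operator, can pass the equivalent of `OperatorAssumptions(Val(true))` and never hit the dynamic branch.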

-    # This catches the case where A is a CuMatrix
-    # Which does not have LU fully defined
-    elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray
-        if VERSION >= v"1.8-"
-            alg = LUFactorization()
-        else
-            alg = QRFactorization()
-        end
+function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
+    GenericFactorization(; fact_alg = lu!)
+end
+function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{false})
+    GenericFactorization(; fact_alg = qr!)
+end
+function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
+    GenericFactorization(; fact_alg = ldlt!)
+end
 
-    # Not factorizable operator, default to only using A*x
-    else
-        alg = KrylovJL_GMRES()
-    end
-    alg
-end
+function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
+    if length(b) <= 10_000
+        KLUFactorization()
+    else
+        UMFPACKFactorization()
+    end
+end
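A hypothetical REPL check of the dispatch table above (`defaultalg` is an internal function, so this is illustrative only, not public API):

    using LinearSolve, SparseArrays, LinearAlgebra

    b = rand(100)

    # Square sparse systems with length(b) <= 10_000 route to KLU:
    LinearSolve.defaultalg(sprand(100, 100, 0.05), b,
                           LinearSolve.OperatorAssumptions(Val(true)))
    # -> KLUFactorization()

    # A square Tridiagonal routes to an lu!-based GenericFactorization:
    T = Tridiagonal(rand(99), rand(100), rand(99))
    LinearSolve.defaultalg(T, b, LinearSolve.OperatorAssumptions(Val(true)))
    # -> GenericFactorization(; fact_alg = lu!)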

-## Other dispatches are to decrease the dispatch cost
-
-function SciMLBase.solve(cache::LinearCache, alg::Nothing,
-                         args...; kwargs...)
-    @unpack A = cache
-    if A isa DiffEqArrayOperator
-        A = A.A
-    end
-
-    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
-    # it makes sense according to the benchmarks, which is dependent on
-    # whether MKL or OpenBLAS is being used
-    if A isa Matrix
-        b = cache.b
-        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
-           ArrayInterfaceCore.can_setindex(b)
-            if length(b) <= 10
-                alg = GenericLUFactorization()
-                SciMLBase.solve(cache, alg, args...; kwargs...)
-            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
-                   eltype(A) <: Union{Float32, Float64}
-                alg = RFLUFactorization()
-                SciMLBase.solve(cache, alg, args...; kwargs...)
-            #elseif A isa Matrix
-            #    alg = FastLUFactorization()
-            #    SciMLBase.solve(cache, alg, args...; kwargs...)
-            else
-                alg = LUFactorization()
-                SciMLBase.solve(cache, alg, args...; kwargs...)
-            end
-        else
-            alg = LUFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
-        end
+function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{true})
+    if VERSION >= v"1.8-"
+        LUFactorization()
+    else
+        QRFactorization()
+    end
+end
+
+function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true})
+    if VERSION >= v"1.8-"
+        LUFactorization()
+    else
+        QRFactorization()
+    end
+end

-    # These few cases ensure the choice is optimal without the
-    # dynamic dispatching of factorize
-    elseif A isa Tridiagonal
-        alg = GenericFactorization(; fact_alg = lu!)
-        SciMLBase.solve(cache, alg, args...; kwargs...)
-    elseif A isa SymTridiagonal
-        alg = GenericFactorization(; fact_alg = ldlt!)
-        SciMLBase.solve(cache, alg, args...; kwargs...)
-    elseif A isa SparseMatrixCSC
-        b = cache.b
-        if length(b) <= 10_000
-            alg = KLUFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
-        else
-            alg = UMFPACKFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
-        end
+# Not factorizable operator, default to only using A*x
+function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
+                    assumptions::OperatorAssumptions)
+    KrylovJL_GMRES()
+end
 
-    # This catches the cases where a factorization overload could exist
-    # For example, BlockBandedMatrix
-    elseif ArrayInterfaceCore.isstructured(A)
-        alg = GenericFactorization()
-        SciMLBase.solve(cache, alg, args...; kwargs...)
+# Ambiguity handling
+function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
+                    assumptions::OperatorAssumptions{nothing})
+    KrylovJL_GMRES()
+end
 
-    # This catches the case where A is a CuMatrix
-    # Which does not have LU fully defined
-    elseif A isa GPUArraysCore.AbstractGPUArray
-        if VERSION >= v"1.8-"
-            alg = LUFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
-        else
-            alg = QRFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
-        end
-
-    # Not factorizable operator, default to only using A*x
-    # IterativeSolvers is faster on CPU but not GPU-compatible
-    else
-        alg = KrylovJL_GMRES()
-        SciMLBase.solve(cache, alg, args...; kwargs...)
-    end
-end
+# Handle ambiguity
+function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
+                    ::OperatorAssumptions{true})
+    if VERSION >= v"1.8-"
+        LUFactorization()
+    else
+        QRFactorization()
+    end
+end

-function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-    if A isa DiffEqArrayOperator
-        A = A.A
-    end
-
-    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
-    # it makes sense according to the benchmarks, which is dependent on
-    # whether MKL or OpenBLAS is being used
-    if A isa Matrix
-        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
-           ArrayInterfaceCore.can_setindex(b)
-            if length(b) <= 10
-                alg = GenericLUFactorization()
-                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
-                   eltype(A) <: Union{Float32, Float64}
-                alg = RFLUFactorization()
-                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-            #elseif A isa Matrix
-            #    alg = FastLUFactorization()
-            #    init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-            else
-                alg = LUFactorization()
-                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-            end
-        else
-            alg = LUFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-        end
-
-    # These few cases ensure the choice is optimal without the
-    # dynamic dispatching of factorize
-    elseif A isa Tridiagonal
-        alg = GenericFactorization(; fact_alg = lu!)
-        init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-    elseif A isa SymTridiagonal
-        alg = GenericFactorization(; fact_alg = ldlt!)
-        init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-    elseif A isa SparseMatrixCSC
-        if length(b) <= 10_000
-            alg = KLUFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-        else
-            alg = UMFPACKFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-        end
-
-    # This catches the cases where a factorization overload could exist
-    # For example, BlockBandedMatrix
-    elseif ArrayInterfaceCore.isstructured(A)
-        alg = GenericFactorization()
-        init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-
-    # This catches the case where A is a CuMatrix
-    # Which does not have LU fully defined
-    elseif A isa GPUArraysCore.AbstractGPUArray
-        if VERSION >= v"1.8-"
-            alg = LUFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-        else
-            alg = QRFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-        end
-
-    # Not factorizable operator, default to only using A*x
-    # IterativeSolvers is faster on CPU but not GPU-compatible
-    else
-        alg = KrylovJL_GMRES()
-        init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
-    end
-end
+function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{false})
+    QRFactorization()
+end
+
+function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false})
+    QRFactorization()
+end
+
+# Handle ambiguity
+function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
+                    ::OperatorAssumptions{false})
+    QRFactorization()
+end
+
+# Allows A === nothing as a stand-in for dense matrix
+function defaultalg(A, b, ::OperatorAssumptions{true})
+    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
+    # it makes sense according to the benchmarks, which is dependent on
+    # whether MKL or OpenBLAS is being used
+    if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
+        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
+           ArrayInterfaceCore.can_setindex(b)
+            if length(b) <= 10
+                alg = GenericLUFactorization()
+            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
+                   eltype(A) <: Union{Float32, Float64}
+                alg = RFLUFactorization()
+            #elseif A === nothing || A isa Matrix
+            #    alg = FastLUFactorization()
+            else
+                alg = LUFactorization()
+            end
+        else
+            alg = LUFactorization()
+        end
+
+    # This catches the cases where a factorization overload could exist
+    # For example, BlockBandedMatrix
+    elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
+        alg = GenericFactorization()
+
+    # Not factorizable operator, default to only using A*x
+    # IterativeSolvers is faster on CPU but not GPU-compatible
+    else
+        alg = KrylovJL_GMRES()
+    end
+    alg
+end

+function defaultalg(A, b, ::OperatorAssumptions{false})
+    QRFactorization()
+end
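The nonsquare default above is the user-visible payoff of the PR title. A sketch of using it (assuming `solve` forwards the `assumptions` keyword through `init` as shown earlier):

    using LinearSolve

    A = rand(6, 3)   # overdetermined system, square LU does not apply
    b = rand(6)
    prob = LinearProblem(A, b)

    # Declaring the operator nonsquare routes the default to
    # QRFactorization(), which yields the least-squares solution.
    sol = solve(prob, nothing; assumptions = OperatorAssumptions(Val(false)))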

+## Catch high level interface

+function SciMLBase.solve(cache::LinearCache, alg::Nothing,
+                         args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
+                         kwargs...)
+    @unpack A, b = cache
+    SciMLBase.solve(cache, defaultalg(A, b, assumptions), args...; kwargs...)
+end

+function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
+                       verbose::Bool, assumptions::OperatorAssumptions)
+    init_cacheval(defaultalg(A, b, assumptions), A, b, u, Pl, Pr, maxiters, abstol,
+                  reltol, verbose, assumptions)
+end