Add an assumptions mechanism for type-stable default help, and nonsquare #187
Merged

Changes from 5 of 13 commits (all by ChrisRackauckas):

881cdfb  Add an assumptions mechanism for type-stable default help, and nonsquare
47aadfe  format
c18bcba  fix typos in dispatches
36db954  handle a few more cases
fdd694b  format
c691e2f  typo
175ffd2  correct nothing
d7fa117  fix type decision
d480a21  Add missing dispatch hints
3b029d9  add missing term in dispatch
c475a98  add fallback for older algorithms that assumes square
d20042b  fix ambiguities
52f5da5  fix downstream
The diff is a single hunk (@@ -1,209 +1,136 @@): the old runtime if/else bodies of defaultalg(A, b), SciMLBase.solve(cache, ::Nothing), and init_cacheval(::Nothing, ...) are replaced. The algorithm choice moves into defaultalg methods that dispatch on an OperatorAssumptions type parameter, and the solve/init_cacheval entry points delegate to it. The file after these five commits:

## Default algorithm

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
    defaultalg(A.A, b, assumptions)
end

# Ambiguity handling
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
    defaultalg(A.A, b, assumptions)
end

# An assumption of `nothing` means squareness is unknown: check it once at
# runtime, then lift the answer into the type domain
function defaultalg(A, b, ::OperatorAssumptions{nothing})
    issquare = size(A, 1) == size(A, 2)
    defaultalg(A, b, OperatorAssumptions(Val(issquare)))
end
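This {nothing} method is the single dynamic point in the new design: squareness is checked once at runtime, and everything downstream dispatches on OperatorAssumptions{true} or OperatorAssumptions{false}. A minimal self-contained sketch of the pattern; the Assumptions struct and pick function here are hypothetical mocks, not LinearSolve's definitions:

# Hypothetical mock of the PR's pattern, not LinearSolve code
struct Assumptions{issq} end
Assumptions(::Val{issq}) where {issq} = Assumptions{issq}()
Assumptions() = Assumptions{nothing}()

# Unknown squareness: one runtime check, then re-dispatch on the Val-typed answer
pick(A, ::Assumptions{nothing}) = pick(A, Assumptions(Val(size(A, 1) == size(A, 2))))
pick(A, ::Assumptions{true}) = :lu    # square: factorize
pick(A, ::Assumptions{false}) = :qr   # nonsquare: least squares

pick(rand(3, 3), Assumptions())  # :lu
pick(rand(4, 2), Assumptions())  # :qr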
# These few cases ensure the choice is optimal without the
# dynamic dispatching of factorize
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
    GenericFactorization(; fact_alg = lu!)
end
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{false})
    GenericFactorization(; fact_alg = qr!)
end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
    GenericFactorization(; fact_alg = ldlt!)
end

function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
    if length(b) <= 10_000
        KLUFactorization()
    else
        UMFPACKFactorization()
    end
end
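For reference, the fact_alg choices above are in-place factorizations from the LinearAlgebra standard library, which has specialized methods for these banded types. A quick standard-library-only check:

using LinearAlgebra

T = Tridiagonal(rand(4), rand(5) .+ 2, rand(4))
S = SymTridiagonal(rand(5) .+ 2, rand(4))

lu!(copy(T))    # banded LU, the square Tridiagonal default above
ldlt!(copy(S))  # LDLt, the SymTridiagonal default above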
# This catches the case where A is a CuMatrix
# Which does not have LU fully defined
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{true})
    if VERSION >= v"1.8-"
        LUFactorization()
    else
        QRFactorization()
    end
end

function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true})
    if VERSION >= v"1.8-"
        LUFactorization()
    else
        QRFactorization()
    end
end

# Not factorizable operator, default to only using A*x
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
                    assumptions::OperatorAssumptions)
    KrylovJL_GMRES()
end
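A SciMLBase.AbstractDiffEqOperator may only define matrix-vector products, so factorization-based defaults are off the table and the choice is a Krylov method regardless of assumptions. A usage sketch through LinearSolve's public interface; a plain matrix stands in for the operator here, and this assumes a LinearSolve version exposing KrylovJL_GMRES:

using LinearAlgebra, LinearSolve

A = rand(100, 100) + 10I             # well-conditioned stand-in for an operator
b = rand(100)
prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES())  # GMRES only needs A*x products
sol.u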
# Ambiguity handling
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
                    assumptions::OperatorAssumptions{nothing})
    KrylovJL_GMRES()
end

# Handle ambiguity
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{true})
    if VERSION >= v"1.8-"
        LUFactorization()
    else
        QRFactorization()
    end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{false})
    QRFactorization()
end

function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false})
    QRFactorization()
end

# Handle ambiguity
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{false})
    QRFactorization()
end
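The repeated # Handle ambiguity methods are needed because the A-typed and b-typed GPU methods are equally specific when both arguments are GPU arrays, so Julia refuses to pick one. The same situation in a self-contained mock (f is hypothetical, not LinearSolve code):

f(x::Integer, y) = :first
f(x, y::Integer) = :second

try
    f(1, 2)               # neither method is more specific for (Int, Int)
catch err
    println(typeof(err))  # MethodError: ambiguous
end

f(x::Integer, y::Integer) = :both  # the disambiguating method, as in the PR
println(f(1, 2))                   # :both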
# Allows A === nothing as a stand-in for dense matrix
function defaultalg(A, b, ::OperatorAssumptions{true})
    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
    # it makes sense according to the benchmarks, which is dependent on
    # whether MKL or OpenBLAS is being used
    if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
        if (A === nothing ||
            eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
           ArrayInterfaceCore.can_setindex(b)
            if length(b) <= 10
                alg = GenericLUFactorization()
            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
                   eltype(A) <: Union{Float32, Float64}
                alg = RFLUFactorization()
                #elseif A === nothing || A isa Matrix
                #    alg = FastLUFactorization()
            else
                alg = LUFactorization()
            end
        else
            alg = LUFactorization()
        end

        # These few cases ensure the choice is optimal without the
        # dynamic dispatching of factorize
    elseif A isa Tridiagonal
        alg = GenericFactorization(; fact_alg = lu!)
    elseif A isa SymTridiagonal
        alg = GenericFactorization(; fact_alg = ldlt!)
    elseif A isa SparseMatrixCSC
        if length(b) <= 10_000
            alg = KLUFactorization()
        else
            alg = UMFPACKFactorization()
        end

        # This catches the cases where a factorization overload could exist
        # For example, BlockBandedMatrix
    elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
        alg = GenericFactorization()

        # This catches the case where A is a CuMatrix
        # Which does not have LU fully defined
    elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray
        if VERSION >= v"1.8-"
            alg = LUFactorization()
        else
            alg = QRFactorization()
        end

        # Not factorizable operator, default to only using A*x
        # IterativeSolvers is faster on CPU but not GPU-compatible
    else
        alg = KrylovJL_GMRES()
    end
    alg
end

function defaultalg(A, b, ::OperatorAssumptions{false})
    QRFactorization()
end
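QR is the nonsquare default because it produces the least-squares solution when the system is over- or underdetermined. A standard-library illustration of the square-vs-nonsquare split that OperatorAssumptions encodes:

using LinearAlgebra

A = rand(10, 3)       # overdetermined: 10 equations, 3 unknowns
b = rand(10)
x = qr(A) \ b         # least-squares solution, minimizing norm(A*x - b)
norm(A' * (A*x - b))  # ~0: the residual is orthogonal to the columns of A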
## Catch high level interface

function SciMLBase.solve(cache::LinearCache, alg::Nothing,
                         args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
                         kwargs...)
    @unpack A, b = cache
    SciMLBase.solve(cache, defaultalg(A, b, assumptions), args...; kwargs...)
end

function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
                       verbose::Bool, assumptions::OperatorAssumptions)
    init_cacheval(defaultalg(A, b, assumptions), A, b, u, Pl, Pr, maxiters, abstol,
                  reltol, verbose, assumptions)
end
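With these two entry points, a caller can either let the {nothing} path check squareness at runtime or declare it up front. A hedged usage sketch; it assumes the assumptions keyword is forwarded from the problem-level solve down to these methods, as the signature above suggests:

using LinearSolve

A = rand(6, 4)
b = rand(6)
prob = LinearProblem(A, b)

sol1 = solve(prob)  # OperatorAssumptions() is {nothing}: squareness checked once at runtime

# Declaring nonsquareness makes the default (QRFactorization) a type-level decision
sol2 = solve(prob; assumptions = OperatorAssumptions(Val(false)))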
Review comment:
This seems type unstable. Is that intentional, or does Julia always union split here?

Reply:
We need to have some point where a type choice is made, because then it sets the algorithm type. Things like the ODE solver would just set the assumptions at the type level.
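One way to see both sides of this exchange, using hypothetical mock types rather than LinearSolve's: inference through the {nothing} path yields a small Union that Julia can often union-split, while an assumption fixed at the type level is fully concrete.

using Test

struct Assume{issq} end
Assume(::Val{issq}) where {issq} = Assume{issq}()
struct LU end
struct QR end

choose(A, ::Assume{nothing}) = choose(A, Assume(Val(size(A, 1) == size(A, 2))))
choose(A, ::Assume{true}) = LU()
choose(A, ::Assume{false}) = QR()

# Through {nothing}, the inferred result is Union{LU, QR}: type unstable,
# though small unions like this are usually union-split by the compiler.
println(Base.return_types(choose, (Matrix{Float64}, Assume{nothing})))

# With the assumption set at the type level (as an ODE solver would do),
# the algorithm type is fully inferred.
@inferred choose(rand(4, 4), Assume{true}())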