Skip to content

Commit

Permalink
Fix default solve algorithm handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 24, 2021
1 parent cdfe484 commit c4f8ed5
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ function set_cacheval(cache::LinearCache, alg_cache)
return cache
end

init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing

function SciMLBase.init(prob::LinearProblem, alg, args...;
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)

function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
alias_A = false, alias_b = false,
kwargs...,
)
Expand Down Expand Up @@ -83,7 +85,9 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
return cache
end

SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
SciMLBase.solve(prob::LinearProblem, args...; kwargs...) = solve(init(prob, nothing, args...; kwargs...))

SciMLBase.solve(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing},
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))

SciMLBase.solve(cache::LinearCache, args...; kwargs...) =
Expand Down
8 changes: 6 additions & 2 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
args...; kwargs...)
@unpack A = cache
if A isa DiffEqArrayOperator
A = A.A
end

if A isa Matrix
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
(isopenblas() && size(A,1) <= 500)
)
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
alg = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
SciMLBase.solve(cache, alg, args...; kwargs...)
else
alg = LUFactorization()
Expand Down
16 changes: 15 additions & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ end
function init_cacheval(alg::LUFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end
fact = lu!(A, alg.pivot)
return fact
end
Expand All @@ -49,7 +53,10 @@ function init_cacheval(alg::QRFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("QR is not defined for $(typeof(A))")

fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
if A isa AbstractDiffEqOperator
A = A.A
end
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
return fact
end

Expand All @@ -66,6 +73,10 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("SVD is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end

fact = svd!(A; full = alg.full, alg = alg.alg)
return fact
end
Expand All @@ -83,6 +94,9 @@ function init_cacheval(alg::GenericFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("GenericFactorization is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end
fact = alg.fact_alg(A)
return fact
end
11 changes: 9 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,19 @@ function test_interface(alg, prob1, prob2)
return
end

@testset "Default Linear Solver" begin
test_interface(nothing, prob1, prob2)

A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
y = solve(prob1)
@test A1 * y b1
end

@testset "Concrete Factorizations" begin
for alg in (
LUFactorization(),
QRFactorization(),
SVDFactorization(),
#nothing
SVDFactorization()
)
@testset "$alg" begin
test_interface(alg, prob1, prob2)
Expand Down

0 comments on commit c4f8ed5

Please sign in to comment.