Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix default solve algorithm handling #25

Merged
merged 2 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ function set_cacheval(cache::LinearCache, alg_cache)
return cache
end

init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing

function SciMLBase.init(prob::LinearProblem, alg, args...;
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)

function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
alias_A = false, alias_b = false,
kwargs...,
)
Expand Down Expand Up @@ -83,7 +85,9 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
return cache
end

SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
SciMLBase.solve(prob::LinearProblem, args...; kwargs...) = solve(init(prob, nothing, args...; kwargs...))

SciMLBase.solve(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing},
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))

SciMLBase.solve(cache::LinearCache, args...; kwargs...) =
Expand Down
8 changes: 6 additions & 2 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
args...; kwargs...)
@unpack A = cache
if A isa DiffEqArrayOperator
A = A.A
end

if A isa Matrix
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
(isopenblas() && size(A,1) <= 500)
)
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
alg = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
SciMLBase.solve(cache, alg, args...; kwargs...)
else
alg = LUFactorization()
Expand Down
18 changes: 16 additions & 2 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization)
cache = set_cacheval(cache, fact)
end

ldiv!(cache.u,cache.cacheval, cache.b)
ldiv!(cache.u, cache.cacheval, cache.b)
end

## LUFactorization
Expand All @@ -25,6 +25,10 @@ end
function init_cacheval(alg::LUFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end
fact = lu!(A, alg.pivot)
return fact
end
Expand All @@ -49,7 +53,10 @@ function init_cacheval(alg::QRFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("QR is not defined for $(typeof(A))")

fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
if A isa AbstractDiffEqOperator
A = A.A
end
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
return fact
end

Expand All @@ -66,6 +73,10 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("SVD is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end

fact = svd!(A; full = alg.full, alg = alg.alg)
return fact
end
Expand All @@ -83,6 +94,9 @@ function init_cacheval(alg::GenericFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("GenericFactorization is not defined for $(typeof(A))")

if A isa AbstractDiffEqOperator
A = A.A
end
fact = alg.fact_alg(A)
return fact
end
33 changes: 30 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, LinearAlgebra
using LinearSolve, LinearAlgebra, SparseArrays
using Test

n = 8
Expand Down Expand Up @@ -32,12 +32,39 @@ function test_interface(alg, prob1, prob2)
return
end

@testset "Default Linear Solver" begin
test_interface(nothing, prob1, prob2)

A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
y = solve(prob1)
@test A1 * y ≈ b1

_prob = LinearProblem(SymTridiagonal(A1.A), b1; u0=x1)
y = solve(prob1)
@test A1 * y ≈ b1

_prob = LinearProblem(Tridiagonal(A1.A), b1; u0=x1)
y = solve(prob1)
@test A1 * y ≈ b1

_prob = LinearProblem(Symmetric(A1.A), b1; u0=x1)
y = solve(prob1)
@test A1 * y ≈ b1

_prob = LinearProblem(Hermitian(A1.A), b1; u0=x1)
y = solve(prob1)
@test A1 * y ≈ b1

_prob = LinearProblem(sparse(A1.A), b1; u0=x1)
y = solve(prob1)
@test A1 * y ≈ b1
end

@testset "Concrete Factorizations" begin
for alg in (
LUFactorization(),
QRFactorization(),
SVDFactorization(),
#nothing
SVDFactorization()
)
@testset "$alg" begin
test_interface(alg, prob1, prob2)
Expand Down