For SimpleGMRES we need to reinitialize some cache when b is set again
avik-pal committed Nov 5, 2023
1 parent a9b5581 commit 2afc31d
Showing 4 changed files with 22 additions and 7 deletions.
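The change matters when a LinearCache is reused for a new right-hand side, which is what the updated GPU test below exercises. As a rough CPU sketch of that workflow (the matrix, vectors, and tolerances here are illustrative and not part of the commit):

using LinearSolve, LinearAlgebra

# Illustrative problem; sizes, names, and tolerances are not from the commit.
A  = rand(32, 32) + 5I
b1 = rand(32)
b2 = rand(32)

cache = init(LinearProblem(A, b1), SimpleGMRES(); abstol = 1e-10, reltol = 1e-10)
solve!(cache)                # first solve, against b1

# Reassigning `b` goes through the `setproperty!` hook added in this commit,
# which reinitializes SimpleGMRES's workspace before the next solve.
cache.b = b2
sol = solve!(cache)
sol.u ≈ A \ b2               # should hold up to the requested tolerance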
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.15.0"
version = "2.15.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8 changes: 8 additions & 0 deletions src/common.jl
@@ -85,13 +85,21 @@ end
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
setfield!(cache, :isfresh, true)
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
elseif name === :cacheval && cache.alg isa DefaultLinearSolver
@assert cache.cacheval isa DefaultLinearSolverInit
return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
end
setfield!(cache, name, x)
end

function update_cacheval!(cache::LinearCache, name::Symbol, x)
return update_cacheval!(cache, cache.cacheval, name, x)
end
update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval

init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing

function SciMLBase.init(prob::LinearProblem, args...; kwargs...)
7 changes: 7 additions & 0 deletions src/simplegmres.jl
@@ -72,6 +72,13 @@ end
warm_start::Bool
end

function update_cacheval!(cache::LinearCache, cacheval::SimpleGMRESCache, name::Symbol, x)
(name != :b || cache.isfresh) && return cacheval
vec(cacheval.w) .= vec(x)
fill!(cacheval.x, 0)
return cacheval
end
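Together with the `setproperty!` branch added in src/common.jl, this method runs whenever `cache.b` is reassigned while `A` is unchanged, copying the new right-hand side into the workspace `w` and zeroing the previous iterate `x`. A self-contained toy sketch of that dispatch, using hypothetical ToyCache/ToyGMRESCache stand-ins rather than the real library structs:

# Hypothetical stand-ins for LinearCache and SimpleGMRESCache, to show the dispatch only.
mutable struct ToyGMRESCache
    w::Vector{Float64}   # workspace that holds the right-hand side
    x::Vector{Float64}   # current iterate
end

mutable struct ToyCache
    isfresh::Bool        # true when `A` was replaced and not yet refactorized
    cacheval::ToyGMRESCache
end

update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval   # generic no-op fallback

function update_cacheval!(cache, cacheval::ToyGMRESCache, name::Symbol, x)
    (name != :b || cache.isfresh) && return cacheval
    vec(cacheval.w) .= vec(x)    # load the new right-hand side into the workspace
    fill!(cacheval.x, 0)         # drop the stale iterate from the previous solve
    return cacheval
end

c = ToyCache(false, ToyGMRESCache(zeros(2), [0.5, 0.5]))
update_cacheval!(c, c.cacheval, :b, [3.0, 4.0])
c.cacheval.w, c.cacheval.x       # ([3.0, 4.0], [0.0, 0.0])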

"""
(c, s, ρ) = _sym_givens(a, b)
12 changes: 6 additions & 6 deletions test/gpu/cuda.jl
@@ -25,19 +25,19 @@ function test_interface(alg, prob1, prob2)
x2 = prob2.u0

y = solve(prob1, alg; cache_kwargs...)
-@test A1 * y ≈ b1
+@test Array(A1 * y) ≈ Array(b1)

cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
solve!(cache)
-@test A1 * cache.u ≈ b1
+@test Array(A1 * cache.u) ≈ Array(b1)

cache.A = copy(A2)
solve!(cache)
-@test A2 * cache.u ≈ b1
+@test Array(A2 * cache.u) ≈ Array(b1)

cache.b = copy(b2)
solve!(cache)
-@test A2 * cache.u ≈ b2
+@test Array(A2 * cache.u) ≈ Array(b2)

return
end
@@ -62,8 +62,8 @@ using BlockDiagonals
A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
b = rand(size(A, 1)) |> cu

-x1 = zero(b)
-x2 = zero(b)
+x1 = zero(b) |> cu
+x2 = zero(b) |> cu
prob1 = LinearProblem(A, b, x1)
prob2 = LinearProblem(A, b, x2)
