Skip to content

Commit

Permalink
Merge pull request #371 from Circuitscape/RA/pardiso_lib
Browse files Browse the repository at this point in the history
Move Pardiso solver to lib
  • Loading branch information
ranjanan authored Dec 26, 2022
2 parents 6056353 + c12f237 commit 9d74ec2
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 52 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Expand All @@ -27,6 +26,5 @@ ArchGDAL = "0.8, 0.9"
GZip = "0.5.1"
Graphs = "1"
IterativeSolvers = "0.9"
Pardiso = "0.5.4"
SimpleWeightedGraphs = "1.2"
julia = "1.6"
13 changes: 13 additions & 0 deletions lib/CircuitscapeMKLPardiso/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name = "CircuitscapeMKLPardiso"
uuid = "f276a096-2d87-4069-b10e-355a3501ee6e"
authors = ["ranjanan <[email protected]>"]
version = "0.1.0"

[deps]
Circuitscape = "2b7a1792-8151-5239-925d-e2b8fdfa3201"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

[compat]
Circuitscape = "5.11"
Pardiso = "0.5"
julia = "1.6"
59 changes: 59 additions & 0 deletions lib/CircuitscapeMKLPardiso/src/CircuitscapeMKLPardiso.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module CircuitscapeMKLPardiso

using Pardiso
import Circuitscape: Solver, _check_eltype, MKLPardisoSolver, construct_cholesky_factor, solve_linear_system

mutable struct MKLPardisoFactorize
A
ps::Pardiso.MKLPardisoSolver
verbose::Bool
firsttime::Bool
end
MKLPardisoFactorize(;verbose=false) = MKLPardisoFactorize(nothing,Pardiso.MKLPardisoSolver(),verbose,true)
function (p::MKLPardisoFactorize)(x,A,b,update_matrix=false;kwargs...)
if p.firsttime
Pardiso.set_phase!(p.ps, Pardiso.ANALYSIS_NUM_FACT)
Pardiso.pardiso(p.ps, x, A, b)
p.firsttime = false
end

if update_matrix
Pardiso.set_phase!(p.ps, Pardiso.NUM_FACT)
Pardiso.pardiso(p.ps, x, A, b)
p.A = A
end

Pardiso.set_phase!(p.ps, Pardiso.SOLVE_ITERATIVE_REFINE)
Pardiso.pardiso(p.ps, x, A, b)
end

function compute_mklpardiso(str, batch_size = 5)
cfg = parse_config(str)
T = cfg["precision"] in SINGLE ? Float32 : Float64
if T == Float32
cswarn("Pardiso supports only double precision. Changing precision to double.")
T = Float64
end
V = cfg["use_64bit_indexing"] in TRUELIST ? Int64 : Int32
cfg["solver"] = "mklpardiso"
_compute(T, V, cfg)
end


_check_eltype(a, solver::MKLPardisoSolver) = a
construct_cholesky_factor(matrix, ::MKLPardisoSolver, suppress_info::Bool) =
MKLPardisoFactorize()

function solve_linear_system(factor::MKLPardisoFactorize, matrix, rhs)
lhs = similar(rhs)
mat = sparse(10eps(eltype(matrix))*I,size(matrix)...) + matrix
x = zeros(eltype(matrix), size(matrix, 1))
for i = 1:size(lhs, 2)
factor(x, mat, rhs[:,i])
@assert (norm(mat*x .- rhs[:,i]) / norm(rhs[:,i])) < 1e-6
lhs[:,i] .= x
end
lhs
end

end # module CircuitscapeMKLPardiso
1 change: 0 additions & 1 deletion src/Circuitscape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using Graphs
using SimpleWeightedGraphs
using IterativeSolvers
using GZip
using Pardiso

using LinearAlgebra
using SparseArrays
Expand Down
14 changes: 0 additions & 14 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ function _check_eltype(a, solver::CholmodSolver)
cswarn("CHOLMOD only works with double precision. Converting single precision matrix to double")
Float64.(a)
end
_check_eltype(a, solver::MKLPardisoSolver) = a

function solve(prob::GraphProblem{T,V}, solver::Union{CholmodSolver, MKLPardisoSolver}, flags,
cfg, log) where {T,V}
Expand Down Expand Up @@ -498,8 +497,6 @@ function construct_cholesky_factor(matrix, ::CholmodSolver, suppress_info::Bool)
csinfo("Time taken to construct cholesky factor = $t", suppress_info)
factor
end
construct_cholesky_factor(matrix, ::MKLPardisoSolver, suppress_info::Bool) =
MKLPardisoFactorize()


"""
Expand Down Expand Up @@ -617,17 +614,6 @@ function solve_linear_system(
v
end

function solve_linear_system(factor::MKLPardisoFactorize, matrix, rhs)
lhs = similar(rhs)
mat = sparse(10eps(eltype(matrix))*I,size(matrix)...) + matrix
x = zeros(eltype(matrix), size(matrix, 1))
for i = 1:size(lhs, 2)
factor(x, mat, rhs[:,i])
@assert (norm(mat*x .- rhs[:,i]) / norm(rhs[:,i])) < 1e-6
lhs[:,i] .= x
end
lhs
end

function solve_linear_system(factor::SuiteSparse.CHOLMOD.Factor, matrix, rhs)
lhs = factor \ rhs
Expand Down
35 changes: 0 additions & 35 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,6 @@ function compute_cholmod(str, batch_size = 5)
cfg["cholmod_batch_size"] = string(batch_size)
_compute(T, V, cfg)
end
function compute_mklpardiso(str, batch_size = 5)
cfg = parse_config(str)
T = cfg["precision"] in SINGLE ? Float32 : Float64
if T == Float32
cswarn("Pardiso supports only double precision. Changing precision to double.")
T = Float64
end
V = cfg["use_64bit_indexing"] in TRUELIST ? Int64 : Int32
cfg["solver"] = "mklpardiso"
_compute(T, V, cfg)
end

function compute_single(str)
cfg = parse_config(str)
Expand Down Expand Up @@ -445,30 +434,6 @@ function compare_node(r, x, tol = 1e-6)
sum(abs2, sortslices(r, dims=1) - sortslices(x, dims=1)) < tol
end

### Pardiso option
mutable struct MKLPardisoFactorize
A
ps::Pardiso.MKLPardisoSolver
verbose::Bool
firsttime::Bool
end
MKLPardisoFactorize(;verbose=false) = MKLPardisoFactorize(nothing,Pardiso.MKLPardisoSolver(),verbose,true)
function (p::MKLPardisoFactorize)(x,A,b,update_matrix=false;kwargs...)
if p.firsttime
Pardiso.set_phase!(p.ps, Pardiso.ANALYSIS_NUM_FACT)
Pardiso.pardiso(p.ps, x, A, b)
p.firsttime = false
end

if update_matrix
Pardiso.set_phase!(p.ps, Pardiso.NUM_FACT)
Pardiso.pardiso(p.ps, x, A, b)
p.A = A
end

Pardiso.set_phase!(p.ps, Pardiso.SOLVE_ITERATIVE_REFINE)
Pardiso.pardiso(p.ps, x, A, b)
end

# Function to calculate current for Omniscape moving window solves
function compute_omniscape_current(
Expand Down

0 comments on commit 9d74ec2

Please sign in to comment.